Skip to content

Commit 7e6418e

Browse files
Input preprocessing library to support vocab-dimension mini-batching.
Currently only PMAP is supported for simplicity. JAX support will be added later. PiperOrigin-RevId: 716024250
1 parent b91e774 commit 7e6418e

14 files changed

+2099
-161
lines changed

jax_tpu_embedding/sparsecore/lib/core/BUILD

Lines changed: 50 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
load("//third_party/bazel/python:pybind11.bzl", "pybind_extension")
14+
load("//third_party/bazel/python:pybind11.bzl", "pybind_extension", "pybind_library")
1515
load("//third_party/bazel/python:pypi.bzl", "pypi_requirement")
1616
load("//third_party/bazel/python:pytype.bzl", "pytype_strict_contrib_test", "pytype_strict_library")
1717

@@ -72,17 +72,65 @@ cc_test(
7272
],
7373
)
7474

75+
# Pybind helper library built from input_preprocessing_py_util.cc and exposing
# input_preprocessing_py_util.h. It appears in the deps of both the
# input_preprocessing_cc and input_preprocessing_with_mini_batching_cc
# extension targets in this file.
pybind_library(
76+
name = "input_preprocessing_py_util",
77+
srcs = [
78+
"input_preprocessing_py_util.cc",
79+
],
80+
hdrs = [
81+
"input_preprocessing_py_util.h",
82+
],
83+
deps = [
84+
":input_preprocessing_util",
85+
"@com_google_absl//absl/base:core_headers",
86+
"@com_google_absl//absl/container:flat_hash_map",
87+
"@com_google_absl//absl/log:check",
88+
"@tsl//tsl/profiler/lib:traceme",
89+
],
90+
)
91+
7592
pybind_extension(
7693
name = "input_preprocessing_cc",
77-
srcs = ["input_preprocessing.cc"],
94+
srcs = [
95+
"input_preprocessing.cc",
96+
],
7897
deps = [
98+
":input_preprocessing_py_util",
7999
":input_preprocessing_threads",
80100
":input_preprocessing_util",
101+
"@com_google_absl//absl/base:core_headers",
81102
"@com_google_absl//absl/container:flat_hash_map",
103+
"@com_google_absl//absl/log",
82104
"@com_google_absl//absl/log:check",
83105
"@com_google_absl//absl/strings",
84106
"@com_google_absl//absl/synchronization",
85107
"@com_google_absl//absl/types:span",
108+
"@highway//:hwy",
109+
"@highway//hwy/contrib/sort:vqsort",
110+
"@tsl//tsl/profiler/lib:connected_traceme",
111+
"@tsl//tsl/profiler/lib:traceme",
112+
],
113+
)
114+
115+
pybind_extension(
116+
name = "input_preprocessing_with_mini_batching_cc",
117+
srcs = [
118+
"input_preprocessing_with_mini_batching.cc",
119+
"input_preprocessing_with_mini_batching.h",
120+
],
121+
deps = [
122+
":input_preprocessing_py_util",
123+
":input_preprocessing_threads",
124+
":input_preprocessing_util",
125+
"@com_google_absl//absl/base:core_headers",
126+
"@com_google_absl//absl/container:flat_hash_map",
127+
"@com_google_absl//absl/log",
128+
"@com_google_absl//absl/log:check",
129+
"@com_google_absl//absl/strings",
130+
"@com_google_absl//absl/synchronization",
131+
"@com_google_absl//absl/types:span",
132+
"@highway//:hwy",
133+
"@highway//hwy/contrib/sort:vqsort",
86134
"@tsl//tsl/profiler/lib:connected_traceme",
87135
"@tsl//tsl/profiler/lib:traceme",
88136
],

jax_tpu_embedding/sparsecore/lib/core/input_preprocessing.cc

Lines changed: 1 addition & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414
#include <algorithm>
15-
#include <cmath>
1615
#include <optional>
1716
#include <string>
1817
#include <utility>
@@ -24,6 +23,7 @@
2423
#include "absl/strings/string_view.h" // from @com_google_absl
2524
#include "absl/synchronization/blocking_counter.h" // from @com_google_absl
2625
#include "absl/types/span.h" // from @com_google_absl
26+
#include "jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_py_util.h"
2727
#include "jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_threads.h"
2828
#include "jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_util.h"
2929
#include "pybind11/cast.h" // from @pybind11
@@ -148,48 +148,6 @@ int ExtractCooTensors(const py::array& features,
148148
global_device_count, coo_tensors);
149149
}
150150

151-
absl::flat_hash_map<std::string, std::vector<StackedTableMetadata>>
152-
GetStackedTableMetadata(py::list feature_specs, py::list features) {
153-
tsl::profiler::TraceMe t([] { return "GetStackedTableMetadata"; });
154-
absl::flat_hash_map<std::string, std::vector<StackedTableMetadata>>
155-
stacked_table_metadata;
156-
for (int i = 0; i < feature_specs.size(); ++i) {
157-
const py::object& feature_spec = feature_specs[i];
158-
const py::array& feature = features[i].cast<py::array>();
159-
const py::object& feature_transformation =
160-
feature_spec.attr("_id_transformation");
161-
const py::object& table_spec = feature_spec.attr("table_spec");
162-
const py::object& stacked_table_spec =
163-
table_spec.attr("stacked_table_spec");
164-
const std::string stacked_table_name = py::cast<std::string>(
165-
table_spec.attr("_setting_in_stack").attr("stack_name"));
166-
int col_shift = 0;
167-
int col_offset = 0;
168-
int row_offset = 0;
169-
const int max_ids_per_partition =
170-
py::cast<int>(stacked_table_spec.attr("max_ids_per_partition"));
171-
const int max_unique_ids_per_partition =
172-
py::cast<int>(stacked_table_spec.attr("max_unique_ids_per_partition"));
173-
if (!feature_transformation.is_none()) {
174-
row_offset = py::cast<int>(feature_transformation.attr("row_offset"));
175-
col_shift = py::cast<int>(feature_transformation.attr("col_shift"));
176-
col_offset = py::cast<int>(feature_transformation.attr("col_offset"));
177-
}
178-
stacked_table_metadata[stacked_table_name].emplace_back(
179-
i, max_ids_per_partition, max_unique_ids_per_partition, row_offset,
180-
col_offset, col_shift,
181-
/*batch_size=*/feature.shape(0));
182-
}
183-
// Sort the stacked tables by row_offset.
184-
for (auto& [_, t] : stacked_table_metadata) {
185-
std::sort(t.begin(), t.end(),
186-
[](const StackedTableMetadata& a, const StackedTableMetadata& b) {
187-
return a.row_offset < b.row_offset;
188-
});
189-
}
190-
return stacked_table_metadata;
191-
}
192-
193151
// Preprocess inputs for a single table. Stacked table here refers to a
194152
// table that has no parent in the table stacking hierarchy. So in the case
195153
// of table stacking, the stacked table is the top level table and in the case
jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_py_util.cc

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
// Copyright 2024 The JAX SC Authors.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
#include "jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_py_util.h"
15+
16+
#include <algorithm>
17+
#include <cmath>
18+
#include <string>
19+
#include <utility>
20+
#include <vector>
21+
22+
#include "absl/container/flat_hash_map.h" // from @com_google_absl
23+
#include "absl/log/check.h" // from @com_google_absl
24+
#include "jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_util.h"
25+
#include "pybind11/cast.h" // from @pybind11
26+
#include "pybind11/gil.h" // from @pybind11
27+
#include "pybind11/numpy.h" // from @pybind11
28+
#include "pybind11/pybind11.h" // from @pybind11
29+
#include "pybind11/pytypes.h" // from @pybind11
30+
#include "tsl/profiler/lib/traceme.h" // from @tsl
31+
32+
namespace jax_sc_embedding {
33+
34+
namespace py = ::pybind11;
35+
36+
absl::flat_hash_map<std::string, std::vector<StackedTableMetadata>>
37+
GetStackedTableMetadata(const py::list& feature_specs, const int batch_size) {
38+
tsl::profiler::TraceMe t([] { return "GetStackedTableMetadata"; });
39+
absl::flat_hash_map<std::string, std::vector<StackedTableMetadata>>
40+
stacked_table_metadata;
41+
for (int i = 0; i < feature_specs.size(); ++i) {
42+
const py::object& feature_spec = feature_specs[i];
43+
44+
const py::object& feature_transformation =
45+
feature_spec.attr("_id_transformation");
46+
const py::object& table_spec = feature_spec.attr("table_spec");
47+
const py::object& stacked_table_spec =
48+
table_spec.attr("stacked_table_spec");
49+
const std::string stacked_table_name = py::cast<std::string>(
50+
table_spec.attr("_setting_in_stack").attr("stack_name"));
51+
int col_shift = 0;
52+
int col_offset = 0;
53+
int row_offset = 0;
54+
const int max_ids_per_partition =
55+
py::cast<int>(stacked_table_spec.attr("max_ids_per_partition"));
56+
const int max_unique_ids_per_partition =
57+
py::cast<int>(stacked_table_spec.attr("max_unique_ids_per_partition"));
58+
const int vocab_size =
59+
py::cast<int>(stacked_table_spec.attr("stack_vocab_size"));
60+
if (!feature_transformation.is_none()) {
61+
row_offset = py::cast<int>(feature_transformation.attr("row_offset"));
62+
col_shift = py::cast<int>(feature_transformation.attr("col_shift"));
63+
col_offset = py::cast<int>(feature_transformation.attr("col_offset"));
64+
}
65+
stacked_table_metadata[stacked_table_name].emplace_back(
66+
i, max_ids_per_partition, max_unique_ids_per_partition, row_offset,
67+
col_offset, col_shift,
68+
/*batch_size=*/batch_size, vocab_size);
69+
}
70+
// Sort the stacked tables by row_offset.
71+
for (auto& [_, t] : stacked_table_metadata) {
72+
std::sort(t.begin(), t.end(),
73+
[](const StackedTableMetadata& a, const StackedTableMetadata& b) {
74+
return a.row_offset < b.row_offset;
75+
});
76+
}
77+
return stacked_table_metadata;
78+
}
79+
80+
absl::flat_hash_map<std::string, std::vector<StackedTableMetadata>>
81+
GetStackedTableMetadata(const py::list& feature_specs,
82+
const py::list& features) {
83+
tsl::profiler::TraceMe t([] { return "GetStackedTableMetadata"; });
84+
int batch_size = features[0].cast<py::array>().shape(0);
85+
return GetStackedTableMetadata(feature_specs, batch_size);
86+
}
87+
88+
} // namespace jax_sc_embedding
jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_py_util.h

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
// Copyright 2024 The JAX SC Authors.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
// Include guard named after this header (input_preprocessing_py_util.h) so it
// cannot clash with a guard derived from input_preprocessing.h.
#ifndef JAX_TPU_EMBEDDING_SPARSECORE_LIB_CORE_INPUT_PREPROCESSING_PY_UTIL_H_
#define JAX_TPU_EMBEDDING_SPARSECORE_LIB_CORE_INPUT_PREPROCESSING_PY_UTIL_H_
#include <string>
#include <vector>

#include "absl/container/flat_hash_map.h"  // from @com_google_absl
#include "jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_util.h"
#include "pybind11/numpy.h"  // from @pybind11
#include "pybind11/pytypes.h"  // from @pybind11

namespace jax_sc_embedding {

namespace py = ::pybind11;

// Copy information from feature_specs to StackedTableMetadata.
// The features argument is only used to get the batch size.
// Returns a map keyed by stacked table name.
absl::flat_hash_map<std::string, std::vector<StackedTableMetadata>>
GetStackedTableMetadata(const py::list& feature_specs,
                        const py::list& features);

// Copy information from feature_specs to StackedTableMetadata, using the
// given batch_size for every feature.
// Returns a map keyed by stacked table name.
absl::flat_hash_map<std::string, std::vector<StackedTableMetadata>>
GetStackedTableMetadata(const py::list& feature_specs, int batch_size);

}  // namespace jax_sc_embedding

#endif  // JAX_TPU_EMBEDDING_SPARSECORE_LIB_CORE_INPUT_PREPROCESSING_PY_UTIL_H_

jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_util.h

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ struct CooFormat {
3535

3636
// Get adjusted col_id based on shift and offset.
3737
int GetColId(int col_id, int col_shift, int col_offset, int num_scs_mod,
38-
int num_scs_mod_inv);
38+
int num_scs_mod_inv);
3939

4040
inline unsigned int CeilOfRatio(unsigned int numerator,
4141
unsigned int denominator) {
@@ -50,14 +50,16 @@ struct StackedTableMetadata {
5050
StackedTableMetadata() = delete;
5151
StackedTableMetadata(int feature_index, int max_ids_per_partition,
5252
int max_unique_ids_per_partition, int row_offset,
53-
int col_offset, int col_shift, int batch_size)
53+
int col_offset, int col_shift, int batch_size,
54+
int stacked_table_vocab_size = 0)
5455
: feature_index(feature_index),
5556
max_ids_per_partition(max_ids_per_partition),
5657
max_unique_ids_per_partition(max_unique_ids_per_partition),
5758
row_offset(row_offset),
5859
col_offset(col_offset),
5960
col_shift(col_shift),
60-
batch_size(batch_size) {}
61+
batch_size(batch_size),
62+
stacked_table_vocab_size(stacked_table_vocab_size) {}
6163
// The batch is given as a list of features (numpy arrays). `feature_index`
6264
// represents the index of the feature in the list.
6365
int feature_index;
@@ -70,6 +72,8 @@ struct StackedTableMetadata {
7072

7173
// Process local batch size of the feature.
7274
int batch_size;
75+
76+
int stacked_table_vocab_size;
7377
};
7478

7579
void SortAndGroupCooTensors(

0 commit comments

Comments
 (0)