Skip to content

Commit eb20789

Browse files
Ensure that suggested COO buffer size does not exceed the theoretical max based on limits.
PiperOrigin-RevId: 816192724
1 parent 58695a0 commit eb20789

File tree

7 files changed

+35
-18
lines changed

7 files changed

+35
-18
lines changed

jax_tpu_embedding/sparsecore/lib/core/input_preprocessing.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,8 @@ struct TableState {
134134
stacked_table_metadata(metadata),
135135
coo_buffer_size_per_device(
136136
ComputeCooBufferSizePerDevice(num_scs, options.num_sc_per_device,
137-
metadata, options.batch_number)),
137+
metadata, options.batch_number,
138+
options.enable_minibatching)),
138139
csr_arrays_per_host(options.local_device_count,
139140
row_pointers_size_per_bucket *
140141
(options.enable_minibatching

jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_test.cc

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -713,11 +713,9 @@ TEST_F(MinibatchingCountTest, SingleHostMinibatchCountIsCorrectWhenRequired) {
713713
// Also increase buffer size.
714714
stacked_table_metadata_[0].max_ids_per_partition = 5;
715715
stacked_table_metadata_[0].max_unique_ids_per_partition = 2;
716-
stacked_table_metadata_[0].suggested_coo_buffer_size_per_device = 2048;
717716

718717
stacked_table_metadata_[1].max_ids_per_partition = 5;
719718
stacked_table_metadata_[1].max_unique_ids_per_partition = 6;
720-
stacked_table_metadata_[1].suggested_coo_buffer_size_per_device = 2048;
721719

722720
auto input_batches =
723721
CreateInputBatches(/*max_ids_per_partitions=*/{10, 20},
@@ -799,8 +797,6 @@ TEST_F(MinibatchingCountTest, MultiHostMinibatchCountIsCorrectWhenRequired) {
799797
absl::Mutex mutex;
800798
std::vector<int> minibatches_per_host(kHosts, -1);
801799

802-
stacked_table_metadata_[0].suggested_coo_buffer_size_per_device = 8192;
803-
stacked_table_metadata_[1].suggested_coo_buffer_size_per_device = 8192;
804800
absl::flat_hash_map<std::string, std::vector<StackedTableMetadata>>
805801
stacked_tables({{"table_0", stacked_table_metadata_}});
806802

@@ -850,8 +846,6 @@ TEST_F(MinibatchingCountTest, MultiHostMinibatchCountIsCorrectWhenOneRequires) {
850846
absl::Mutex mutex;
851847
std::vector<int> minibatches_per_host(kHosts, -1);
852848

853-
stacked_table_metadata_[0].suggested_coo_buffer_size_per_device = 8192;
854-
stacked_table_metadata_[1].suggested_coo_buffer_size_per_device = 8192;
855849
absl::flat_hash_map<std::string, std::vector<StackedTableMetadata>>
856850
stacked_tables({{"table_0", stacked_table_metadata_}});
857851

jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_util.cc

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -210,10 +210,29 @@ RowCombiner GetRowCombiner(absl::string_view combiner) {
210210
return RowCombiner::kSum;
211211
}
212212

213+
int64_t MayBeUpdateBufferSize(
214+
int64_t theoretical_max,
215+
std::optional<int64_t> suggested_coo_buffer_size_per_device,
216+
int num_scs_per_device, absl::string_view stacked_table_name) {
217+
// Since the suggested size corresponds to only current device (local SCs),
218+
// Buffer for each SC should be properly aligned, hence ALIGNMENT *
219+
// num_scs_per_device
220+
int64_t suggested_value = RoundUpTo<int64_t>(
221+
suggested_coo_buffer_size_per_device.value(),
222+
TPU_VECTOR_REGISTER_ALIGMENT_SIZE * num_scs_per_device);
223+
CHECK(suggested_value <= theoretical_max)
224+
<< "Suggested Coo Buffer Size is larger than the theoretical "
225+
"max for table "
226+
<< stacked_table_name << ": " << suggested_value << " vs "
227+
<< theoretical_max
228+
<< ". Adjust the suggested size or the max_ids_per_partition values.";
229+
return suggested_value;
230+
}
231+
213232
int ComputeCooBufferSizePerDevice(
214233
const int num_scs, const int num_scs_per_device,
215234
absl::Span<const StackedTableMetadata> stacked_table_metadata,
216-
const int batch_number) {
235+
const int batch_number, bool use_minibatching) {
217236
const int max_ids_per_partition =
218237
MaxIdsPerPartitionForStackedTables(stacked_table_metadata);
219238
const std::optional<int> suggested_coo_buffer_size_per_device =
@@ -222,7 +241,8 @@ int ComputeCooBufferSizePerDevice(
222241
const int64_t max_ids_rounded_up = RoundUpTo<int64_t>(
223242
max_ids_per_partition, TPU_VECTOR_REGISTER_ALIGMENT_SIZE);
224243
const int64_t theoretical_max =
225-
max_ids_rounded_up * num_scs_per_device * num_scs;
244+
max_ids_rounded_up * num_scs_per_device * num_scs *
245+
(use_minibatching ? CooFormat::kMaxMinibatchingBuckets : 1);
226246
const std::string& stacked_table_name = stacked_table_metadata[0].name;
227247
LOG_IF(INFO, batch_number % 100 == 0)
228248
<< "Theoretical Max for table " << stacked_table_name << ": "
@@ -234,12 +254,9 @@ int ComputeCooBufferSizePerDevice(
234254
LOG_IF(INFO, batch_number % 100 == 0)
235255
<< "Suggested Coo Buffer Size for table " << stacked_table_name << ": "
236256
<< suggested_coo_buffer_size_per_device.value();
237-
// Since the suggested size corresponds to only current device (local SCs),
238-
// Buffer for each SC should be properly aligned, hence ALIGNMENT *
239-
// num_scs_per_device
240-
result = RoundUpTo<int64_t>(
241-
suggested_coo_buffer_size_per_device.value(),
242-
TPU_VECTOR_REGISTER_ALIGMENT_SIZE * num_scs_per_device);
257+
result = MayBeUpdateBufferSize(
258+
theoretical_max, suggested_coo_buffer_size_per_device,
259+
num_scs_per_device, stacked_table_name);
243260
} else {
244261
LOG_IF(WARNING, batch_number % 100 == 0)
245262
<< "No Coo Buffer Size provided for table " << stacked_table_name

jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_util.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -293,7 +293,7 @@ struct StackedTableMetadata {
293293
int ComputeCooBufferSizePerDevice(
294294
int num_scs, int num_scs_per_device,
295295
absl::Span<const StackedTableMetadata> stacked_table_metadata,
296-
int batch_number = 0);
296+
int batch_number = 0, bool use_minibatching = false);
297297

298298
int MaxIdsPerPartitionForStackedTables(
299299
absl::Span<const StackedTableMetadata> stacked_table_metadata);

jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_util_test.cc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,12 @@ TEST(InputPreprocessingUtilTest, ComputeCooBufferSize) {
139139
/*num_scs_per_device=*/4,
140140
stacked_table_metadata),
141141
96);
142+
stacked_table_metadata[0].suggested_coo_buffer_size_per_device = 1024;
143+
// The theoretical max is 16 * 4 * 4 = 256. This is less than the suggestion.
144+
EXPECT_DEATH(ComputeCooBufferSizePerDevice(/*num_scs=*/4,
145+
/*num_scs_per_device=*/4,
146+
stacked_table_metadata),
147+
".*Check failed: suggested_value <= theoretical_max.*");
142148
}
143149

144150
TEST(SortAndGroupTest, Base) {

jax_tpu_embedding/sparsecore/lib/nn/tests/minibatching_test.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,6 @@ def setUp(self):
137137
optimizer=embedding_spec.SGDOptimizerSpec(),
138138
combiner="sum",
139139
name="table_a",
140-
suggested_coo_buffer_size_per_device=8192,
141140
)
142141
self.feature_spec = embedding_spec.FeatureSpec(
143142
table_spec=self.table_spec,
@@ -404,7 +403,6 @@ def setUp(self):
404403
optimizer=embedding_spec.SGDOptimizerSpec(),
405404
combiner="sum",
406405
name="table_a",
407-
suggested_coo_buffer_size_per_device=16384,
408406
)
409407
self.feature_spec = embedding_spec.FeatureSpec(
410408
table_spec=self.table_spec,

jax_tpu_embedding/sparsecore/lib/nn/tests/preprocess_input_benchmarks.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ def generate_feature_specs(num_features: int, num_samples: int):
8181
total_sample_count=num_samples,
8282
max_ids_per_partition=1024,
8383
max_unique_ids_per_partition=1024,
84+
suggested_coo_buffer_size_per_device=4096,
8485
),
8586
)
8687
feature_spec = embedding_spec.FeatureSpec(

0 commit comments

Comments
 (0)