Skip to content

Commit 309128b

Browse files
[JAX SC] Pass table name and max vocab ID to SparseCore input batch.
This change modifies `RaggedTensorInputBatch` to accept a table name and maximum vocabulary ID. These values are then passed to the `SparseCsrInputBatchStream` for potential use in error checking or logging during SparseCore embedding processing. PiperOrigin-RevId: 821096546
1 parent 0278e82 commit 309128b

File tree

2 files changed

+14
-4
lines changed

2 files changed

+14
-4
lines changed

jax_tpu_embedding/sparsecore/lib/core/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -308,6 +308,7 @@ cc_library(
308308
":sparse_csr_input_stream",
309309
":unity_weights_stream",
310310
"@com_google_absl//absl/log:check",
311+
"@com_google_absl//absl/strings:string_view",
311312
],
312313
)
313314

jax_tpu_embedding/sparsecore/lib/core/ragged_tensor_input_batch.h

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,10 @@
1414
#ifndef JAX_TPU_EMBEDDING_SPARSECORE_LIB_CORE_RAGGED_TENSOR_INPUT_BATCH_H_
1515
#define JAX_TPU_EMBEDDING_SPARSECORE_LIB_CORE_RAGGED_TENSOR_INPUT_BATCH_H_
1616
#include <cstdint>
17+
#include <limits>
1718

1819
#include "absl/log/check.h" // from @com_google_absl
20+
#include "absl/strings/string_view.h" // from @com_google_absl
1921
#include "jax_tpu_embedding/sparsecore/lib/core/abstract_input_batch.h"
2022
#include "jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_util.h"
2123
#include "jax_tpu_embedding/sparsecore/lib/core/process_coo_tensors_impl.h"
@@ -63,16 +65,21 @@ class RaggedTensorInputBatch : public AbstractInputBatch {
6365
// This class represents a batch of input data encoded using row offsets,
6466
// similar to how RaggedTensor uses row offsets as described in
6567
// https://www.tensorflow.org/guide/ragged_tensor#tfraggedtensorfrom_row_splits.
66-
RaggedTensorInputBatch(EmbeddingIdsView embedding_ids,
67-
RowOffsetsView row_offsets)
68-
: embedding_ids_(embedding_ids), row_offsets_(row_offsets) {}
68+
RaggedTensorInputBatch(
69+
EmbeddingIdsView embedding_ids, RowOffsetsView row_offsets,
70+
absl::string_view table_name = "unknown_table_name",
71+
int64_t max_vocab_id = std::numeric_limits<int64_t>::max())
72+
: embedding_ids_(embedding_ids),
73+
row_offsets_(row_offsets),
74+
table_name_(table_name),
75+
max_vocab_id_(max_vocab_id) {}
6976

7077
int64_t size() const override { return row_offsets_.size() - 1; }
7178
void ExtractCooTensors(const ExtractCooTensorsOptions& options,
7279
ExtractedCooTensors& coo_tensors) override {
7380
SparseCsrInputBatchStream<int64_t, EmbeddingIdsView, RowOffsetsView>
7481
values_stream(embedding_ids_, row_offsets_, options.slice_start,
75-
options.slice_end);
82+
options.slice_end, table_name_, max_vocab_id_);
7683
UnityWeightsStream weights_stream(values_stream);
7784

7885
ProcessCooTensors(options, values_stream, weights_stream, coo_tensors);
@@ -81,6 +88,8 @@ class RaggedTensorInputBatch : public AbstractInputBatch {
8188
private:
8289
EmbeddingIdsView embedding_ids_;
8390
RowOffsetsView row_offsets_;
91+
absl::string_view table_name_;
92+
int64_t max_vocab_id_;
8493
};
8594

8695
// deduction guide for compiler

0 commit comments

Comments
 (0)