Skip to content

Commit 45bddee

Browse files
Add batch number to TraceMe annotations in input preprocessing.
This change enhances profiling by including the `batch_number` as metadata in various `tsl::profiler::TraceMe` calls within the input preprocessing pipeline. This allows for better analysis of performance across different batches.

PiperOrigin-RevId: 832467473
1 parent c1f92d7 commit 45bddee

File tree

1 file changed

+57
-32
lines changed

1 file changed

+57
-32
lines changed

jax_tpu_embedding/sparsecore/lib/core/input_preprocessing.cc

Lines changed: 57 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -164,29 +164,31 @@ void ExtractSortAndGroupCooTensorsForTable(
164164
const PreprocessSparseDenseMatmulInputOptions& options,
165165
absl::BlockingCounter& counter) {
166166
tsl::profiler::TraceMe traceme([&] {
167-
return absl::StrCat("InputPreprocessingTable-ExtractSortGroup-",
168-
state.stacked_table_name);
167+
return tsl::profiler::TraceMeEncode(
168+
absl::StrCat("InputPreprocessingTable-ExtractSortGroup-",
169+
state.stacked_table_name),
170+
{{"batch_number", options.batch_number}});
169171
});
170172
for (int local_device = 0; local_device < options.local_device_count;
171173
++local_device) {
172-
PreprocessingThreadPool()->Schedule([&, local_device, &state = state,
173-
input_batches] {
174-
state.extracted_coo_tensors_per_device[local_device] =
175-
internal::ExtractCooTensorsForAllFeaturesPerLocalDevice(
176-
state.stacked_table_metadata, input_batches, local_device,
177-
options);
178-
179-
internal::StatsPerDevice stats_per_device =
180-
state.stats_per_host.GetStatsPerDevice(local_device);
181-
state.partitioned_coo_tensors_per_device[local_device] =
182-
SortAndGroupCooTensorsPerLocalDevice(
183-
state.extracted_coo_tensors_per_device[local_device],
184-
state.stacked_table_metadata[0], options, stats_per_device,
185-
state.table_minibatching_required);
186-
state.dropped_id_count_per_device[local_device] =
187-
stats_per_device.dropped_id_count;
188-
counter.DecrementCount();
189-
});
174+
PreprocessingThreadPool()->Schedule(
175+
[&, local_device, &state = state, input_batches] {
176+
state.extracted_coo_tensors_per_device[local_device] =
177+
internal::ExtractCooTensorsForAllFeaturesPerLocalDevice(
178+
state.stacked_table_metadata, input_batches, local_device,
179+
options);
180+
181+
internal::StatsPerDevice stats_per_device =
182+
state.stats_per_host.GetStatsPerDevice(local_device);
183+
state.partitioned_coo_tensors_per_device[local_device] =
184+
SortAndGroupCooTensorsPerLocalDevice(
185+
state.extracted_coo_tensors_per_device[local_device],
186+
state.stacked_table_metadata[0], options, stats_per_device,
187+
state.table_minibatching_required);
188+
state.dropped_id_count_per_device[local_device] =
189+
stats_per_device.dropped_id_count;
190+
counter.DecrementCount();
191+
});
190192
}
191193
}
192194

@@ -210,8 +212,10 @@ void CreateMinibatchingBucketsForTable(
210212
TableState& state, const PreprocessSparseDenseMatmulInputOptions& options,
211213
absl::BlockingCounter& counter) {
212214
tsl::profiler::TraceMe traceme([&] {
213-
return absl::StrCat("InputPreprocessingTable-CreateMinibatchingBuckets-",
214-
state.stacked_table_name);
215+
return tsl::profiler::TraceMeEncode(
216+
absl::StrCat("InputPreprocessingTable-CreateMinibatchingBuckets-",
217+
state.stacked_table_name),
218+
{{"batch_number", options.batch_number}});
215219
});
216220
state.stats_per_host.dropped_id_count = 0;
217221
for (int local_device = 0; local_device < options.local_device_count;
@@ -369,7 +373,10 @@ void MergeStats(
369373
absl::StatusOr<bool> SyncMinibatchingRequired(
370374
const PreprocessSparseDenseMatmulInputOptions& options,
371375
absl::Span<const TableState> table_states) {
372-
tsl::profiler::TraceMe traceme("SyncMinibatchingRequired");
376+
tsl::profiler::TraceMe traceme([&] {
377+
return tsl::profiler::TraceMeEncode(
378+
"SyncMinibatchingRequired", {{"batch_number", options.batch_number}});
379+
});
373380
if (!options.enable_minibatching) {
374381
return false;
375382
}
@@ -395,7 +402,10 @@ absl::StatusOr<bool> SyncMinibatchingRequired(
395402
absl::StatusOr<MinibatchingSplit> SyncMinibatchingSplit(
396403
const PreprocessSparseDenseMatmulInputOptions& options,
397404
absl::Span<const TableState> table_states) {
398-
tsl::profiler::TraceMe traceme("SyncMinibatchingSplit");
405+
tsl::profiler::TraceMe traceme([&] {
406+
return tsl::profiler::TraceMeEncode(
407+
"SyncMinibatchingSplit", {{"batch_number", options.batch_number}});
408+
});
399409
MinibatchingSplit local_minibatching_split = 0;
400410
for (const auto& state : table_states) {
401411
local_minibatching_split |= state.table_minibatching_split;
@@ -454,8 +464,10 @@ void FillDeviceBuffersForTable(
454464
MinibatchingSplit global_minibatching_split,
455465
absl::BlockingCounter& counter) {
456466
tsl::profiler::TraceMe traceme([&] {
457-
return absl::StrCat("InputPreprocessingTable-FillBuffer-",
458-
state.stacked_table_name);
467+
return tsl::profiler::TraceMeEncode(
468+
absl::StrCat("InputPreprocessingTable-FillBuffer-",
469+
state.stacked_table_name),
470+
{{"batch_number", options.batch_number}});
459471
});
460472
for (int local_device = 0; local_device < options.local_device_count;
461473
++local_device) {
@@ -503,9 +515,11 @@ PreprocessSparseDenseMatmulInput(
503515
const absl::flat_hash_map<std::string, std::vector<StackedTableMetadata>>&
504516
stacked_tables,
505517
const PreprocessSparseDenseMatmulInputOptions& options) {
506-
tsl::profiler::TraceMe traceme([=, &options] {
507-
return absl::StrCat("input_preprocessing_cc-", options.local_device_count,
508-
"/", options.global_device_count);
518+
tsl::profiler::TraceMe traceme([&] {
519+
return tsl::profiler::TraceMeEncode(
520+
absl::StrCat("input_preprocessing_cc-", options.local_device_count, "/",
521+
options.global_device_count),
522+
{{"batch_number", options.batch_number}});
509523
});
510524
if (options.sharding_strategy != ShardingStrategy::kMod) {
511525
LOG(FATAL) << "Only mod sharding is supported for now.";
@@ -530,7 +544,11 @@ PreprocessSparseDenseMatmulInput(
530544

531545
// Stage 1: COO Extraction and Initial Sort/Group
532546
{
533-
tsl::profiler::TraceMe traceme("ExtractSortAndGroupCooTensors");
547+
tsl::profiler::TraceMe traceme([&] {
548+
return tsl::profiler::TraceMeEncode(
549+
"ExtractSortAndGroupCooTensors",
550+
{{"batch_number", options.batch_number}});
551+
});
534552
absl::BlockingCounter counter(table_states.size() *
535553
options.local_device_count);
536554
for (auto& state : table_states) {
@@ -552,7 +570,11 @@ PreprocessSparseDenseMatmulInput(
552570

553571
if (options.enable_minibatching && global_minibatching_required) {
554572
{
555-
tsl::profiler::TraceMe traceme("CreateMinibatchingBuckets");
573+
tsl::profiler::TraceMe traceme([&] {
574+
return tsl::profiler::TraceMeEncode(
575+
"CreateMinibatchingBuckets",
576+
{{"batch_number", options.batch_number}});
577+
});
556578
absl::BlockingCounter counter(table_states.size() *
557579
options.local_device_count);
558580
for (auto& state : table_states) {
@@ -570,7 +592,10 @@ PreprocessSparseDenseMatmulInput(
570592

571593
// Stage 3: Fill Device Buffers
572594
{
573-
tsl::profiler::TraceMe traceme("FillDeviceBuffers");
595+
tsl::profiler::TraceMe traceme([&] {
596+
return tsl::profiler::TraceMeEncode(
597+
"FillDeviceBuffers", {{"batch_number", options.batch_number}});
598+
});
574599
absl::BlockingCounter counter(table_states.size() *
575600
options.local_device_count);
576601
for (auto& state : table_states) {

0 commit comments

Comments (0)