@@ -164,29 +164,31 @@ void ExtractSortAndGroupCooTensorsForTable(
164164 const PreprocessSparseDenseMatmulInputOptions& options,
165165 absl::BlockingCounter& counter) {
166166 tsl::profiler::TraceMe traceme ([&] {
167- return absl::StrCat (" InputPreprocessingTable-ExtractSortGroup-" ,
168- state.stacked_table_name );
167+ return tsl::profiler::TraceMeEncode (
168+ absl::StrCat (" InputPreprocessingTable-ExtractSortGroup-" ,
169+ state.stacked_table_name ),
170+ {{" batch_number" , options.batch_number }});
169171 });
170172 for (int local_device = 0 ; local_device < options.local_device_count ;
171173 ++local_device) {
172- PreprocessingThreadPool ()->Schedule ([&, local_device, &state = state,
173- input_batches] {
174- state.extracted_coo_tensors_per_device [local_device] =
175- internal::ExtractCooTensorsForAllFeaturesPerLocalDevice (
176- state.stacked_table_metadata , input_batches, local_device,
177- options);
178-
179- internal::StatsPerDevice stats_per_device =
180- state.stats_per_host .GetStatsPerDevice (local_device);
181- state.partitioned_coo_tensors_per_device [local_device] =
182- SortAndGroupCooTensorsPerLocalDevice (
183- state.extracted_coo_tensors_per_device [local_device],
184- state.stacked_table_metadata [0 ], options, stats_per_device,
185- state.table_minibatching_required );
186- state.dropped_id_count_per_device [local_device] =
187- stats_per_device.dropped_id_count ;
188- counter.DecrementCount ();
189- });
174+ PreprocessingThreadPool ()->Schedule (
175+ [&, local_device, &state = state, input_batches] {
176+ state.extracted_coo_tensors_per_device [local_device] =
177+ internal::ExtractCooTensorsForAllFeaturesPerLocalDevice (
178+ state.stacked_table_metadata , input_batches, local_device,
179+ options);
180+
181+ internal::StatsPerDevice stats_per_device =
182+ state.stats_per_host .GetStatsPerDevice (local_device);
183+ state.partitioned_coo_tensors_per_device [local_device] =
184+ SortAndGroupCooTensorsPerLocalDevice (
185+ state.extracted_coo_tensors_per_device [local_device],
186+ state.stacked_table_metadata [0 ], options, stats_per_device,
187+ state.table_minibatching_required );
188+ state.dropped_id_count_per_device [local_device] =
189+ stats_per_device.dropped_id_count ;
190+ counter.DecrementCount ();
191+ });
190192 }
191193}
192194
@@ -210,8 +212,10 @@ void CreateMinibatchingBucketsForTable(
210212 TableState& state, const PreprocessSparseDenseMatmulInputOptions& options,
211213 absl::BlockingCounter& counter) {
212214 tsl::profiler::TraceMe traceme ([&] {
213- return absl::StrCat (" InputPreprocessingTable-CreateMinibatchingBuckets-" ,
214- state.stacked_table_name );
215+ return tsl::profiler::TraceMeEncode (
216+ absl::StrCat (" InputPreprocessingTable-CreateMinibatchingBuckets-" ,
217+ state.stacked_table_name ),
218+ {{" batch_number" , options.batch_number }});
215219 });
216220 state.stats_per_host .dropped_id_count = 0 ;
217221 for (int local_device = 0 ; local_device < options.local_device_count ;
@@ -369,7 +373,10 @@ void MergeStats(
369373absl::StatusOr<bool > SyncMinibatchingRequired (
370374 const PreprocessSparseDenseMatmulInputOptions& options,
371375 absl::Span<const TableState> table_states) {
372- tsl::profiler::TraceMe traceme (" SyncMinibatchingRequired" );
376+ tsl::profiler::TraceMe traceme ([&] {
377+ return tsl::profiler::TraceMeEncode (
378+ " SyncMinibatchingRequired" , {{" batch_number" , options.batch_number }});
379+ });
373380 if (!options.enable_minibatching ) {
374381 return false ;
375382 }
@@ -395,7 +402,10 @@ absl::StatusOr<bool> SyncMinibatchingRequired(
395402absl::StatusOr<MinibatchingSplit> SyncMinibatchingSplit (
396403 const PreprocessSparseDenseMatmulInputOptions& options,
397404 absl::Span<const TableState> table_states) {
398- tsl::profiler::TraceMe traceme (" SyncMinibatchingSplit" );
405+ tsl::profiler::TraceMe traceme ([&] {
406+ return tsl::profiler::TraceMeEncode (
407+ " SyncMinibatchingSplit" , {{" batch_number" , options.batch_number }});
408+ });
399409 MinibatchingSplit local_minibatching_split = 0 ;
400410 for (const auto & state : table_states) {
401411 local_minibatching_split |= state.table_minibatching_split ;
@@ -454,8 +464,10 @@ void FillDeviceBuffersForTable(
454464 MinibatchingSplit global_minibatching_split,
455465 absl::BlockingCounter& counter) {
456466 tsl::profiler::TraceMe traceme ([&] {
457- return absl::StrCat (" InputPreprocessingTable-FillBuffer-" ,
458- state.stacked_table_name );
467+ return tsl::profiler::TraceMeEncode (
468+ absl::StrCat (" InputPreprocessingTable-FillBuffer-" ,
469+ state.stacked_table_name ),
470+ {{" batch_number" , options.batch_number }});
459471 });
460472 for (int local_device = 0 ; local_device < options.local_device_count ;
461473 ++local_device) {
@@ -503,9 +515,11 @@ PreprocessSparseDenseMatmulInput(
503515 const absl::flat_hash_map<std::string, std::vector<StackedTableMetadata>>&
504516 stacked_tables,
505517 const PreprocessSparseDenseMatmulInputOptions& options) {
506- tsl::profiler::TraceMe traceme ([=, &options] {
507- return absl::StrCat (" input_preprocessing_cc-" , options.local_device_count ,
508- " /" , options.global_device_count );
518+ tsl::profiler::TraceMe traceme ([&] {
519+ return tsl::profiler::TraceMeEncode (
520+ absl::StrCat (" input_preprocessing_cc-" , options.local_device_count , " /" ,
521+ options.global_device_count ),
522+ {{" batch_number" , options.batch_number }});
509523 });
510524 if (options.sharding_strategy != ShardingStrategy::kMod ) {
511525 LOG (FATAL) << " Only mod sharding is supported for now." ;
@@ -530,7 +544,11 @@ PreprocessSparseDenseMatmulInput(
530544
531545 // Stage 1: COO Extraction and Initial Sort/Group
532546 {
533- tsl::profiler::TraceMe traceme (" ExtractSortAndGroupCooTensors" );
547+ tsl::profiler::TraceMe traceme ([&] {
548+ return tsl::profiler::TraceMeEncode (
549+ " ExtractSortAndGroupCooTensors" ,
550+ {{" batch_number" , options.batch_number }});
551+ });
534552 absl::BlockingCounter counter (table_states.size () *
535553 options.local_device_count );
536554 for (auto & state : table_states) {
@@ -552,7 +570,11 @@ PreprocessSparseDenseMatmulInput(
552570
553571 if (options.enable_minibatching && global_minibatching_required) {
554572 {
555- tsl::profiler::TraceMe traceme (" CreateMinibatchingBuckets" );
573+ tsl::profiler::TraceMe traceme ([&] {
574+ return tsl::profiler::TraceMeEncode (
575+ " CreateMinibatchingBuckets" ,
576+ {{" batch_number" , options.batch_number }});
577+ });
556578 absl::BlockingCounter counter (table_states.size () *
557579 options.local_device_count );
558580 for (auto & state : table_states) {
@@ -570,7 +592,10 @@ PreprocessSparseDenseMatmulInput(
570592
571593 // Stage 3: Fill Device Buffers
572594 {
573- tsl::profiler::TraceMe traceme (" FillDeviceBuffers" );
595+ tsl::profiler::TraceMe traceme ([&] {
596+ return tsl::profiler::TraceMeEncode (
597+ " FillDeviceBuffers" , {{" batch_number" , options.batch_number }});
598+ });
574599 absl::BlockingCounter counter (table_states.size () *
575600 options.local_device_count );
576601 for (auto & state : table_states) {
0 commit comments