Commit c1f92d7

[JAX SC] Refactor _get_stack_table_names into helper functions.
This change breaks down the logic of `_get_stack_table_names` into three new private functions: `_group_tables_for_stacking`, `_calculate_activation_memory_metrics`, and `_split_groups_by_memory_limit`. Dedicated unit tests are added for `_split_groups_by_memory_limit`.

PiperOrigin-RevId: 832388245
1 parent 5752890 commit c1f92d7

2 files changed: 133 additions & 19 deletions


jax_tpu_embedding/sparsecore/lib/nn/table_stacking.py

Lines changed: 51 additions & 19 deletions
@@ -373,15 +373,11 @@ def round_up_dim_and_vocab_size(
   return table_to_padded_dim, table_to_padded_vocab_size
 
 
-def _get_stack_table_names(
+def _group_tables_for_stacking(
     num_sc: int,
     flatten_tables: Mapping[str, embedding_spec.TableSpec],
-    flatten_features: Sequence[embedding_spec.FeatureSpec],
-    activation_mem_bytes_limit: int,
-) -> Sequence[Sequence[str]]:
-  """Returns the stack groups for the tables based on their specs."""
-  original_table_names = set(flatten_tables.keys())
-
+) -> list[list[str]]:
+  """Groups table names by padded dimension, optimizer, and combiner."""
   table_to_padded_dim, _ = round_up_dim_and_vocab_size(flatten_tables, num_sc)
   table_name_map = collections.defaultdict(list)
   for table_name, dim in table_to_padded_dim.items():
@@ -391,9 +387,15 @@ def _get_stack_table_names(
         flatten_tables[table_name].combiner,
     )
     table_name_map[key].append(table_name)
+  return list(table_name_map.values())
 
-  groups = list(table_name_map.values())
 
+def _calculate_activation_memory_metrics(
+    num_sc: int,
+    flatten_tables: Mapping[str, embedding_spec.TableSpec],
+    flatten_features: Sequence[embedding_spec.FeatureSpec],
+) -> tuple[Mapping[str, int], Mapping[str, int]]:
+  """Calculates sample count and activation memory bytes per table."""
   # Calculate sample_count per sparsecore for each table.
   table_to_sample_count = collections.defaultdict(int)
   for feature in flatten_features:
@@ -402,11 +404,52 @@ def _get_stack_table_names(
     )
 
   # Calculate and register the activation memory usage of this table.
+  table_to_padded_dim, _ = round_up_dim_and_vocab_size(flatten_tables, num_sc)
   table_to_activation_mem_bytes = {
       tname: table_to_padded_dim[tname] * table_to_sample_count[tname] * 4
       for tname in flatten_tables.keys()
   }
+  return table_to_sample_count, table_to_activation_mem_bytes
+
+
+def _get_stack_table_names(
+    num_sc: int,
+    flatten_tables: Mapping[str, embedding_spec.TableSpec],
+    flatten_features: Sequence[embedding_spec.FeatureSpec],
+    activation_mem_bytes_limit: int,
+) -> Sequence[Sequence[str]]:
+  """Returns the stack groups for the tables based on their specs."""
+  original_table_names = set(flatten_tables.keys())
+
+  groups = _group_tables_for_stacking(num_sc, flatten_tables)
+
+  _, table_to_activation_mem_bytes = _calculate_activation_memory_metrics(
+      num_sc, flatten_tables, flatten_features
+  )
+
+  validated_groups = _split_groups_by_memory_limit(
+      groups, table_to_activation_mem_bytes, activation_mem_bytes_limit
+  )
+
+  grouped_table_names = set()
+  for group in validated_groups:
+    grouped_table_names.update(group)
+
+  if original_table_names != grouped_table_names:
+    raise ValueError(
+        "Table names are not grouped correctly. Original table names:"
+        f" {original_table_names}, grouped table names: {grouped_table_names}"
+    )
+
+  return validated_groups
 
+
+def _split_groups_by_memory_limit(
+    groups: list[list[str]],
+    table_to_activation_mem_bytes: Mapping[str, int],
+    activation_mem_bytes_limit: int,
+) -> list[list[str]]:
+  """Splits table groups to respect the activation memory limit."""
   validated_groups = []
   for group in groups:
     # A list of groups that are split from the current group.
@@ -446,17 +489,6 @@ def _get_stack_table_names(
     )
     # Add into the validated groups.
     validated_groups.extend(split_groups)
-
-  grouped_table_names = set()
-  for group in validated_groups:
-    grouped_table_names.update(group)
-
-  if original_table_names != grouped_table_names:
-    raise ValueError(
-        "Table names are not grouped correctly. Original table names:"
-        f" {original_table_names}, grouped table names: {grouped_table_names}"
-    )
-
   return validated_groups
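The greedy loop inside `_split_groups_by_memory_limit` falls between the last two hunks and is mostly elided above. As a rough sketch only — behavior inferred from the function's signature here and from the parameterized cases added in the test file below, not the library's actual implementation — the splitting amounts to a greedy scan over each group. The stand-in name `greedy_split_by_memory_limit` below is made up for illustration.

from collections.abc import Mapping


def greedy_split_by_memory_limit(
    groups: list[list[str]],
    table_to_activation_mem_bytes: Mapping[str, int],
    activation_mem_bytes_limit: int,
) -> list[list[str]]:
  """Sketch of a greedy split; NOT the actual _split_groups_by_memory_limit."""
  validated_groups = []
  for group in groups:
    current: list[str] = []
    current_bytes = 0
    for table_name in group:
      table_bytes = table_to_activation_mem_bytes[table_name]
      if current and current_bytes + table_bytes > activation_mem_bytes_limit:
        # Adding this table would exceed the limit: close the current sub-group.
        validated_groups.append(current)
        current, current_bytes = [], 0
      # A table that alone exceeds the limit still gets its own sub-group,
      # matching the all-60-bytes / limit=50 test case below.
      current.append(table_name)
      current_bytes += table_bytes
    if current:
      validated_groups.append(current)
  return validated_groups

The twelve parameterized cases added in the test file below all produce the same outputs under this sketch, though the real implementation may differ in details the tests do not cover (for example, any warning or logging around oversized tables).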

jax_tpu_embedding/sparsecore/lib/nn/tests/table_stacking_test.py

Lines changed: 82 additions & 0 deletions
@@ -813,6 +813,88 @@ def get_feature_specs() -> list[embedding_spec.FeatureSpec]:
         num_sc_per_device=self.num_sc_per_device,
     )
 
+  @parameterized.parameters(
+      dict(
+          groups=[['a', 'b', 'c']],
+          table_to_mem={'a': 10, 'b': 20, 'c': 30},
+          limit=100,
+          expected=[['a', 'b', 'c']],
+      ),
+      dict(
+          groups=[['a', 'b', 'c']],
+          table_to_mem={'a': 10, 'b': 20, 'c': 30},
+          limit=30,
+          expected=[['a', 'b'], ['c']],
+      ),
+      dict(
+          groups=[['a', 'b', 'c']],
+          table_to_mem={'a': 10, 'b': 20, 'c': 30},
+          limit=35,
+          expected=[['a', 'b'], ['c']],
+      ),
+      dict(
+          groups=[['a', 'b', 'c', 'd']],
+          table_to_mem={'a': 10, 'b': 10, 'c': 10, 'd': 10},
+          limit=25,
+          expected=[['a', 'b'], ['c', 'd']],
+      ),
+      dict(
+          groups=[['a'], ['b', 'c']],
+          table_to_mem={'a': 10, 'b': 20, 'c': 30},
+          limit=35,
+          expected=[['a'], ['b'], ['c']],
+      ),
+      dict(
+          groups=[['a', 'b'], ['c', 'd']],
+          table_to_mem={'a': 10, 'b': 20, 'c': 5, 'd': 5},
+          limit=30,
+          expected=[['a', 'b'], ['c', 'd']],
+      ),
+      dict(
+          groups=[['a', 'b', 'c']],
+          table_to_mem={'a': 60, 'b': 60, 'c': 60},
+          limit=50,
+          expected=[['a'], ['b'], ['c']],
+      ),
+      dict(
+          groups=[['a', 'b', 'c']],
+          table_to_mem={'a': 10, 'b': 20, 'c': 30},
+          limit=60,
+          expected=[['a', 'b', 'c']],
+      ),
+      dict(
+          groups=[['a', 'b', 'c']],
+          table_to_mem={'a': 10, 'b': 20, 'c': 30},
+          limit=40,
+          expected=[['a', 'b'], ['c']],
+      ),
+      dict(
+          groups=[['a', 'b'], ['c']],
+          table_to_mem={'a': 10, 'b': 20, 'c': 30},
+          limit=30,
+          expected=[['a', 'b'], ['c']],
+      ),
+      dict(
+          groups=[['a', 'b'], ['c', 'd']],
+          table_to_mem={'a': 10, 'b': 15, 'c': 20, 'd': 25},
+          limit=30,
+          expected=[['a', 'b'], ['c'], ['d']],
+      ),
+      dict(
+          groups=[],
+          table_to_mem={'a': 10},
+          limit=100,
+          expected=[],
+      ),
+  )
+  def test_split_groups_by_memory_limit(
+      self, groups, table_to_mem, limit, expected
+  ):
+    result = table_stacking._split_groups_by_memory_limit(
+        groups, table_to_mem, limit
+    )
+    self.assertEqual(result, expected)
+
   @parameterized.product(
       donate=[True, False],
       device_count=[1, 2, 4, -1],
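As a quick interactive check of one of these cases (hedged: this assumes `table_stacking` is importable from `jax_tpu_embedding.sparsecore.lib.nn`, as the test module does, and it calls the private helper directly):

from jax_tpu_embedding.sparsecore.lib.nn import table_stacking

# Mirrors one parameterized case above: four 10-byte tables with a
# 25-byte activation memory limit are expected to split into pairs.
result = table_stacking._split_groups_by_memory_limit(
    [['a', 'b', 'c', 'd']],
    {'a': 10, 'b': 10, 'c': 10, 'd': 10},
    25,
)
print(result)  # Expected, per the test case above: [['a', 'b'], ['c', 'd']]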
