Skip to content

Commit aa0a869

Browse files
[JAX SC] Replace Dict with Mapping.
PiperOrigin-RevId: 814294925
1 parent 58695a0 commit aa0a869

File tree

1 file changed

+7
-7
lines changed

1 file changed

+7
-7
lines changed

jax_tpu_embedding/sparsecore/lib/nn/table_stacking.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import collections
1717
import functools
1818
import hashlib
19-
from typing import Callable, Dict, Mapping, Sequence, TypeAlias, TypeVar
19+
from typing import Callable, Mapping, Sequence, TypeAlias, TypeVar
2020

2121
from absl import logging
2222
import jax
@@ -588,7 +588,7 @@ def _unstack_and_unshard_stacked_table(
588588
stacked_table: jax.Array,
589589
stacked_table_specs: embedding_spec_pb2.StackedTableSpecProto,
590590
donate: bool = False,
591-
) -> Dict[str, jax.Array]:
591+
) -> Mapping[str, jax.Array]:
592592
"""Unstack and unshard the stacked table."""
593593

594594
stacked_table_sharding = stacked_table.sharding
@@ -696,10 +696,10 @@ def _unstack_and_unshard(
696696

697697

698698
def unstack_and_unshard_stacked_tables(
699-
stacked_tables: Dict[str, jax.Array],
699+
stacked_tables: Mapping[str, jax.Array],
700700
embedding_specs: embedding_spec_pb2.EmbeddingSpecProto,
701701
donate: bool = False,
702-
) -> Dict[str, jax.Array]:
702+
) -> Mapping[str, jax.Array]:
703703
"""Unstack and unshard the stacked tables.
704704
705705
Args:
@@ -729,7 +729,7 @@ def unstack_and_unshard_stacked_tables(
729729

730730

731731
def _stack_and_shard_feature_table(
732-
feature_tables: Dict[str, jax.Array],
732+
feature_tables: Mapping[str, jax.Array],
733733
stacked_table_specs: embedding_spec_pb2.StackedTableSpecProto,
734734
delete_input: bool = False,
735735
) -> jax.Array:
@@ -816,10 +816,10 @@ def mod_shard(
816816

817817

818818
def stack_and_shard_feature_tables(
819-
feature_tables: Dict[str, jax.Array],
819+
feature_tables: Mapping[str, jax.Array],
820820
embedding_specs: embedding_spec_pb2.EmbeddingSpecProto,
821821
delete_input: bool = False,
822-
) -> Dict[str, jax.Array]:
822+
) -> Mapping[str, jax.Array]:
823823
"""Stack and shard the feature tables and return the stacked tables.
824824
825825
This function can be run on both TPU or CPU backends. The stacked tables will

0 commit comments

Comments
 (0)