|
16 | 16 | import collections |
17 | 17 | import functools |
18 | 18 | import hashlib |
19 | | -from typing import Callable, Dict, Mapping, Sequence, TypeAlias, TypeVar |
| 19 | +from typing import Callable, Mapping, Sequence, TypeAlias, TypeVar |
20 | 20 |
|
21 | 21 | from absl import logging |
22 | 22 | import jax |
@@ -588,7 +588,7 @@ def _unstack_and_unshard_stacked_table( |
588 | 588 | stacked_table: jax.Array, |
589 | 589 | stacked_table_specs: embedding_spec_pb2.StackedTableSpecProto, |
590 | 590 | donate: bool = False, |
591 | | -) -> Dict[str, jax.Array]: |
| 591 | +) -> Mapping[str, jax.Array]: |
592 | 592 | """Unstack and unshard the stacked table.""" |
593 | 593 |
|
594 | 594 | stacked_table_sharding = stacked_table.sharding |
@@ -696,10 +696,10 @@ def _unstack_and_unshard( |
696 | 696 |
|
697 | 697 |
|
698 | 698 | def unstack_and_unshard_stacked_tables( |
699 | | - stacked_tables: Dict[str, jax.Array], |
| 699 | + stacked_tables: Mapping[str, jax.Array], |
700 | 700 | embedding_specs: embedding_spec_pb2.EmbeddingSpecProto, |
701 | 701 | donate: bool = False, |
702 | | -) -> Dict[str, jax.Array]: |
| 702 | +) -> Mapping[str, jax.Array]: |
703 | 703 | """Unstack and unshard the stacked tables. |
704 | 704 |
|
705 | 705 | Args: |
@@ -729,7 +729,7 @@ def unstack_and_unshard_stacked_tables( |
729 | 729 |
|
730 | 730 |
|
731 | 731 | def _stack_and_shard_feature_table( |
732 | | - feature_tables: Dict[str, jax.Array], |
| 732 | + feature_tables: Mapping[str, jax.Array], |
733 | 733 | stacked_table_specs: embedding_spec_pb2.StackedTableSpecProto, |
734 | 734 | delete_input: bool = False, |
735 | 735 | ) -> jax.Array: |
@@ -816,10 +816,10 @@ def mod_shard( |
816 | 816 |
|
817 | 817 |
|
818 | 818 | def stack_and_shard_feature_tables( |
819 | | - feature_tables: Dict[str, jax.Array], |
| 819 | + feature_tables: Mapping[str, jax.Array], |
820 | 820 | embedding_specs: embedding_spec_pb2.EmbeddingSpecProto, |
821 | 821 | delete_input: bool = False, |
822 | | -) -> Dict[str, jax.Array]: |
| 822 | +) -> Mapping[str, jax.Array]: |
823 | 823 | """Stack and shard the feature tables and return the stacked tables. |
824 | 824 |
|
825 | 825 | This function can be run on both TPU or CPU backends. The stacked tables will |
|
0 commit comments