Skip to content

Commit 05b1ce6

Browse files
[JAX SC] Change tree.flatten vs jax.tree.leaves
PiperOrigin-RevId: 816960160
1 parent 338c8fa commit 05b1ce6

File tree

1 file changed

+18
-17
lines changed

1 file changed

+18
-17
lines changed

jax_tpu_embedding/sparsecore/lib/nn/embedding.py

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ def _verify_feature_specs(
156156
) -> None:
157157
"""Ensures all the fields in the feature specs are correctly defined."""
158158
visited_feature_names = set()
159-
for feature_spec in tree.flatten(feature_specs):
159+
for feature_spec in jax.tree.leaves(feature_specs):
160160
if feature_spec.name in visited_feature_names:
161161
raise ValueError(f"Feature spec {feature_spec.name} is already defined.")
162162
visited_feature_names.add(feature_spec.name)
@@ -166,7 +166,7 @@ def _verify_feature_specs(
166166
def _verify_table_specs(table_specs: Nested[embedding_spec.TableSpec]) -> None:
167167
"""Ensures all the fields in the table specs are correctly defined."""
168168
visited_table_names = set()
169-
for table_spec in tree.flatten(table_specs):
169+
for table_spec in jax.tree.leaves(table_specs):
170170
if table_spec.name in visited_table_names:
171171
raise ValueError(f"Table spec {table_spec.name} is already defined.")
172172
visited_table_names.add(table_spec.name)
@@ -212,7 +212,7 @@ def get_table_specs(
212212
_verify_feature_specs(feature_specs)
213213
table_specs = {
214214
feature_spec.table_spec.name: feature_spec.table_spec
215-
for feature_spec in tree.flatten(feature_specs)
215+
for feature_spec in jax.tree.leaves(feature_specs)
216216
}
217217
_verify_table_specs(table_specs)
218218
return table_specs
@@ -239,14 +239,14 @@ def get_stacked_table_specs(
239239
_verify_feature_specs(feature_specs)
240240
if any(
241241
not feature_spec.table_spec.is_stacked()
242-
for feature_spec in tree.flatten(feature_specs)
242+
for feature_spec in jax.tree.leaves(feature_specs)
243243
):
244244
raise ValueError(
245245
"embedding.prepare_feature_specs_for_training was not called"
246246
)
247247
stacked_table_specs: list[embedding_spec.StackedTableSpec] = [
248248
feature_spec.table_spec.stacked_table_spec
249-
for feature_spec in tree.flatten(feature_specs)
249+
for feature_spec in jax.tree.leaves(feature_specs)
250250
]
251251
return {
252252
stacked_table_specs.stack_name: stacked_table_specs
@@ -281,7 +281,7 @@ def prepare_feature_specs_for_training(
281281
num_sc_per_device = _get_num_sc_per_device(num_sc_per_device)
282282
not_stacked = [
283283
feature
284-
for feature in tree.flatten(feature_specs)
284+
for feature in jax.tree.leaves(feature_specs)
285285
if not feature.table_spec.is_stacked()
286286
]
287287
# Amongst the not explicitly stacked features, collect the ones that point
@@ -365,7 +365,7 @@ def _populate_stacking_info_in_features(
365365
feature,
366366
)
367367

368-
for feature in tree.flatten(feature_specs):
368+
for feature in jax.tree.leaves(feature_specs):
369369
_populate_stacking_info_in_features(feature)
370370

371371

@@ -523,9 +523,9 @@ def preprocess_sparse_dense_matmul_input(
523523

524524
*csr_inputs, num_minibatches, stats = (
525525
pybind_input_preprocessing.PreprocessSparseDenseMatmulInput(
526-
tree.flatten(features),
527-
tree.flatten(features_weights),
528-
tree.flatten(feature_specs),
526+
jax.tree.leaves(features),
527+
jax.tree.leaves(features_weights),
528+
jax.tree.leaves(feature_specs),
529529
local_device_count,
530530
global_device_count,
531531
num_sc_per_device=num_sc_per_device,
@@ -629,10 +629,10 @@ def preprocess_sparse_dense_matmul_input_from_sparse_tensor(
629629

630630
*csr_inputs, num_minibatches, stats = (
631631
pybind_input_preprocessing.PreprocessSparseDenseMatmulSparseCooInput(
632-
tree.flatten(indices),
633-
tree.flatten(values),
634-
tree.flatten(dense_shapes),
635-
tree.flatten(feature_specs),
632+
jax.tree.leaves(indices),
633+
jax.tree.leaves(values),
634+
jax.tree.leaves(dense_shapes),
635+
jax.tree.leaves(feature_specs),
636636
local_device_count,
637637
global_device_count,
638638
num_sc_per_device=num_sc_per_device,
@@ -897,7 +897,8 @@ def stack_embedding_gradients(
897897
str, list[tuple[embedding_spec.FeatureSpec, jax.Array]]
898898
] = collections.defaultdict(list)
899899
for gradient, feature in zip(
900-
tree.flatten(activation_gradients), tree.flatten(feature_specs)
900+
jax.tree.leaves(activation_gradients),
901+
jax.tree.leaves(feature_specs),
901902
):
902903
if feature.id_transformation is None:
903904
raise ValueError(
@@ -1310,7 +1311,7 @@ def init_embedding_variables(
13101311
)
13111312

13121313
stacks = collections.defaultdict(list)
1313-
for table_spec in tree.flatten(table_specs):
1314+
for table_spec in jax.tree.leaves(table_specs):
13141315
stacks[table_spec.setting_in_stack.stack_name].append(table_spec)
13151316

13161317
# Make sure the table specs are sorted by their position in the stack
@@ -1372,7 +1373,7 @@ def create_proto_from_feature_specs(
13721373
str, dict[str, embedding_spec_pb2.TableSpecProto]
13731374
] = collections.defaultdict(dict)
13741375
# Traverse the feature specs and create the StackedTableSpecProto.
1375-
for feature in tree.flatten(feature_specs):
1376+
for feature in jax.tree.leaves(feature_specs):
13761377
current_stack_name = feature.table_spec.stacked_table_spec.stack_name
13771378
current_table_name = feature.table_spec.name
13781379
if current_stack_name not in stacked_table_specs:

0 commit comments

Comments
 (0)