@@ -156,7 +156,7 @@ def _verify_feature_specs(
156156) -> None :
157157 """Ensures all the fields in the feature specs are correctly defined."""
158158 visited_feature_names = set ()
159- for feature_spec in tree .flatten (feature_specs ):
159+ for feature_spec in jax . tree .leaves (feature_specs ):
160160 if feature_spec .name in visited_feature_names :
161161 raise ValueError (f"Feature spec { feature_spec .name } is already defined." )
162162 visited_feature_names .add (feature_spec .name )
@@ -166,7 +166,7 @@ def _verify_feature_specs(
166166def _verify_table_specs (table_specs : Nested [embedding_spec .TableSpec ]) -> None :
167167 """Ensures all the fields in the table specs are correctly defined."""
168168 visited_table_names = set ()
169- for table_spec in tree .flatten (table_specs ):
169+ for table_spec in jax . tree .leaves (table_specs ):
170170 if table_spec .name in visited_table_names :
171171 raise ValueError (f"Table spec { table_spec .name } is already defined." )
172172 visited_table_names .add (table_spec .name )
@@ -212,7 +212,7 @@ def get_table_specs(
212212 _verify_feature_specs (feature_specs )
213213 table_specs = {
214214 feature_spec .table_spec .name : feature_spec .table_spec
215- for feature_spec in tree .flatten (feature_specs )
215+ for feature_spec in jax . tree .leaves (feature_specs )
216216 }
217217 _verify_table_specs (table_specs )
218218 return table_specs
@@ -239,14 +239,14 @@ def get_stacked_table_specs(
239239 _verify_feature_specs (feature_specs )
240240 if any (
241241 not feature_spec .table_spec .is_stacked ()
242- for feature_spec in tree .flatten (feature_specs )
242+ for feature_spec in jax . tree .leaves (feature_specs )
243243 ):
244244 raise ValueError (
245245 "embedding.prepare_feature_specs_for_training was not called"
246246 )
247247 stacked_table_specs : list [embedding_spec .StackedTableSpec ] = [
248248 feature_spec .table_spec .stacked_table_spec
249- for feature_spec in tree .flatten (feature_specs )
249+ for feature_spec in jax . tree .leaves (feature_specs )
250250 ]
251251 return {
252252 stacked_table_specs .stack_name : stacked_table_specs
@@ -281,7 +281,7 @@ def prepare_feature_specs_for_training(
281281 num_sc_per_device = _get_num_sc_per_device (num_sc_per_device )
282282 not_stacked = [
283283 feature
284- for feature in tree .flatten (feature_specs )
284+ for feature in jax . tree .leaves (feature_specs )
285285 if not feature .table_spec .is_stacked ()
286286 ]
287287 # Amongst the not explicitly stacked features, collect the ones that point
@@ -365,7 +365,7 @@ def _populate_stacking_info_in_features(
365365 feature ,
366366 )
367367
368- for feature in tree .flatten (feature_specs ):
368+ for feature in jax . tree .leaves (feature_specs ):
369369 _populate_stacking_info_in_features (feature )
370370
371371
@@ -523,9 +523,9 @@ def preprocess_sparse_dense_matmul_input(
523523
524524 * csr_inputs , num_minibatches , stats = (
525525 pybind_input_preprocessing .PreprocessSparseDenseMatmulInput (
526- tree .flatten (features ),
527- tree .flatten (features_weights ),
528- tree .flatten (feature_specs ),
526+ jax . tree .leaves (features ),
527+ jax . tree .leaves (features_weights ),
528+ jax . tree .leaves (feature_specs ),
529529 local_device_count ,
530530 global_device_count ,
531531 num_sc_per_device = num_sc_per_device ,
@@ -629,10 +629,10 @@ def preprocess_sparse_dense_matmul_input_from_sparse_tensor(
629629
630630 * csr_inputs , num_minibatches , stats = (
631631 pybind_input_preprocessing .PreprocessSparseDenseMatmulSparseCooInput (
632- tree .flatten (indices ),
633- tree .flatten (values ),
634- tree .flatten (dense_shapes ),
635- tree .flatten (feature_specs ),
632+ jax . tree .leaves (indices ),
633+ jax . tree .leaves (values ),
634+ jax . tree .leaves (dense_shapes ),
635+ jax . tree .leaves (feature_specs ),
636636 local_device_count ,
637637 global_device_count ,
638638 num_sc_per_device = num_sc_per_device ,
@@ -897,7 +897,8 @@ def stack_embedding_gradients(
897897 str , list [tuple [embedding_spec .FeatureSpec , jax .Array ]]
898898 ] = collections .defaultdict (list )
899899 for gradient , feature in zip (
900- tree .flatten (activation_gradients ), tree .flatten (feature_specs )
900+ jax .tree .leaves (activation_gradients ),
901+ jax .tree .leaves (feature_specs ),
901902 ):
902903 if feature .id_transformation is None :
903904 raise ValueError (
@@ -1310,7 +1311,7 @@ def init_embedding_variables(
13101311 )
13111312
13121313 stacks = collections .defaultdict (list )
1313- for table_spec in tree .flatten (table_specs ):
1314+ for table_spec in jax . tree .leaves (table_specs ):
13141315 stacks [table_spec .setting_in_stack .stack_name ].append (table_spec )
13151316
13161317 # Make sure the table specs are sorted by their position in the stack
@@ -1372,7 +1373,7 @@ def create_proto_from_feature_specs(
13721373 str , dict [str , embedding_spec_pb2 .TableSpecProto ]
13731374 ] = collections .defaultdict (dict )
13741375 # Traverse the feature specs and create the StackedTableSpecProto.
1375- for feature in tree .flatten (feature_specs ):
1376+ for feature in jax . tree .leaves (feature_specs ):
13761377 current_stack_name = feature .table_spec .stacked_table_spec .stack_name
13771378 current_table_name = feature .table_spec .name
13781379 if current_stack_name not in stacked_table_specs :
0 commit comments