@@ -184,13 +184,14 @@ def test_static_buffer_size_was_too_small(self):
184184 ),
185185 ]
186186 sharding = NamedSharding (mesh , P (None , "x" , None ))
187- embedding_variables ["table" ] = tuple ([
188- jax .make_array_from_single_device_arrays (
187+ embedding_variables ["table" ] = embedding . EmbeddingVariables (
188+ table = jax .make_array_from_single_device_arrays (
189189 shape = (1000 , 8 ),
190190 sharding = sharding ,
191191 arrays = embedding_variables ["table" ],
192- )
193- ])
192+ ),
193+ slot = (),
194+ )
194195 tpu_sparse_dense_matmul = functools .partial (
195196 embedding .tpu_sparse_dense_matmul ,
196197 global_device_count = 1 ,
@@ -408,16 +409,18 @@ def test_sparse_dense_matmul_two_chips_sharded(
408409 global_device_count = mesh .size ,
409410 )
410411 if using_pmap :
411- embedding_variables ["table_a" ] = tuple ([
412+ embedding_variables ["table_a" ] = embedding . EmbeddingVariables (
412413 _create_embedding_variable_for_pmap (
413414 [VariableInfo ((32 , 8 ), 0 )], devices , mesh
414- )
415- ])
416- embedding_variables ["table_aa" ] = tuple ([
415+ ),
416+ (),
417+ )
418+ embedding_variables ["table_aa" ] = embedding .EmbeddingVariables (
417419 _create_embedding_variable_for_pmap (
418420 [VariableInfo ((32 , 8 ), 0 )], devices , mesh
419- )
420- ])
421+ ),
422+ (),
423+ )
421424 activations = jax .pmap (
422425 tpu_sparse_dense_matmul_fn ,
423426 static_broadcasted_argnums = [2 ],
@@ -427,16 +430,18 @@ def test_sparse_dense_matmul_two_chips_sharded(
427430 tuple (tree .flatten (feature_specs )),
428431 )
429432 else :
430- embedding_variables ["table_a" ] = tuple ([
431- _create_embedding_variable_for_jit (
433+ embedding_variables ["table_a" ] = embedding . EmbeddingVariables (
434+ table = _create_embedding_variable_for_jit (
432435 [VariableInfo ((32 , 8 ), 0 )], devices , mesh
433- )
434- ])
435- embedding_variables ["table_aa" ] = tuple ([
436- _create_embedding_variable_for_jit (
436+ ),
437+ slot = (),
438+ )
439+ embedding_variables ["table_aa" ] = embedding .EmbeddingVariables (
440+ table = _create_embedding_variable_for_jit (
437441 [VariableInfo ((32 , 8 ), 0 )], devices , mesh
438- )
439- ])
442+ ),
443+ slot = (),
444+ )
440445 sharded_matmul = functools .partial (
441446 tpu_sparse_dense_matmul_fn ,
442447 feature_specs = tuple (tree .flatten (feature_specs )),
@@ -534,16 +539,17 @@ def test_sparse_dense_matmul_two_chips_sharded_stacked(
534539 global_device_count = mesh .size ,
535540 )
536541 if using_pmap :
537- embedding_variables ["table_a_table_aa" ] = tuple ([
538- _create_embedding_variable_for_pmap (
542+ embedding_variables ["table_a_table_aa" ] = embedding . EmbeddingVariables (
543+ table = _create_embedding_variable_for_pmap (
539544 [
540545 VariableInfo (shape = (64 , 8 ), offset = 0 ),
541546 VariableInfo (shape = (64 , 8 ), offset = 100 ),
542547 ],
543548 devices ,
544549 mesh ,
545- )
546- ])
550+ ),
551+ slot = (),
552+ )
547553 activations = jax .pmap (
548554 tpu_sparse_dense_matmul_fn ,
549555 static_broadcasted_argnums = [2 ],
@@ -553,16 +559,17 @@ def test_sparse_dense_matmul_two_chips_sharded_stacked(
553559 tuple (tree .flatten (feature_specs )),
554560 )
555561 else :
556- embedding_variables ["table_a_table_aa" ] = tuple ([
557- _create_embedding_variable_for_jit (
562+ embedding_variables ["table_a_table_aa" ] = embedding . EmbeddingVariables (
563+ table = _create_embedding_variable_for_jit (
558564 [
559565 VariableInfo (shape = (64 , 8 ), offset = 0 ),
560566 VariableInfo (shape = (64 , 8 ), offset = 100 ),
561567 ],
562568 devices ,
563569 mesh ,
564- )
565- ])
570+ ),
571+ slot = (),
572+ )
566573 sharded_matmul = functools .partial (
567574 tpu_sparse_dense_matmul_fn ,
568575 feature_specs = tuple (tree .flatten (feature_specs )),
@@ -687,16 +694,18 @@ def test_sparse_dense_matmul_single_chip(
687694 global_device_count = mesh .size ,
688695 )
689696 if using_pmap :
690- embedding_variables ["table_a" ] = tuple ([
691- _create_embedding_variable_for_pmap (
697+ embedding_variables ["table_a" ] = embedding . EmbeddingVariables (
698+ table = _create_embedding_variable_for_pmap (
692699 [VariableInfo (shape = (32 , 8 ), offset = 0 )], devices , mesh
693- )
694- ])
695- embedding_variables ["table_b" ] = tuple ([
696- _create_embedding_variable_for_pmap (
700+ ),
701+ slot = (),
702+ )
703+ embedding_variables ["table_b" ] = embedding .EmbeddingVariables (
704+ table = _create_embedding_variable_for_pmap (
697705 [VariableInfo ((64 , 16 ), 0 )], devices , mesh
698- )
699- ])
706+ ),
707+ slot = (),
708+ )
700709 activations = jax .pmap (
701710 tpu_sparse_dense_matmul_fn ,
702711 static_broadcasted_argnums = [2 ],
@@ -706,16 +715,18 @@ def test_sparse_dense_matmul_single_chip(
706715 tuple (tree .flatten (feature_specs )),
707716 )
708717 else :
709- embedding_variables ["table_a" ] = tuple ([
710- _create_embedding_variable_for_jit (
718+ embedding_variables ["table_a" ] = embedding . EmbeddingVariables (
719+ table = _create_embedding_variable_for_jit (
711720 [VariableInfo ((32 , 8 ), 0 )], devices , mesh
712- )
713- ])
714- embedding_variables ["table_b" ] = tuple ([
715- _create_embedding_variable_for_jit (
721+ ),
722+ slot = (),
723+ )
724+ embedding_variables ["table_b" ] = embedding .EmbeddingVariables (
725+ table = _create_embedding_variable_for_jit (
716726 [VariableInfo ((64 , 16 ), 0 )], devices , mesh
717- )
718- ])
727+ ),
728+ slot = (),
729+ )
719730 sparse_matmul = jax .jit (tpu_sparse_dense_matmul_fn , static_argnums = [2 ])
720731 activations = sparse_matmul (
721732 preprocessed_inputs ,
@@ -797,16 +808,18 @@ def test_sparse_dense_matmul_two_tables(
797808 global_device_count = mesh .size ,
798809 )
799810 if using_pmap :
800- embedding_variables ["table_a" ] = tuple ([
801- _create_embedding_variable_for_pmap (
811+ embedding_variables ["table_a" ] = embedding . EmbeddingVariables (
812+ table = _create_embedding_variable_for_pmap (
802813 [VariableInfo ((32 , 8 ), 0 )], devices , mesh
803- )
804- ])
805- embedding_variables ["table_b" ] = tuple ([
806- _create_embedding_variable_for_pmap (
814+ ),
815+ slot = (),
816+ )
817+ embedding_variables ["table_b" ] = embedding .EmbeddingVariables (
818+ table = _create_embedding_variable_for_pmap (
807819 [VariableInfo ((64 , 16 ), 0 )], devices , mesh
808- )
809- ])
820+ ),
821+ slot = (),
822+ )
810823 activations = jax .pmap (
811824 tpu_sparse_dense_matmul_fn ,
812825 static_broadcasted_argnums = (2 ),
@@ -816,16 +829,18 @@ def test_sparse_dense_matmul_two_tables(
816829 tuple (tree .flatten (feature_specs )),
817830 )
818831 else :
819- embedding_variables ["table_a" ] = tuple ([
820- _create_embedding_variable_for_jit (
832+ embedding_variables ["table_a" ] = embedding . EmbeddingVariables (
833+ table = _create_embedding_variable_for_jit (
821834 [VariableInfo ((32 , 8 ), 0 )], devices , mesh
822- )
823- ])
824- embedding_variables ["table_b" ] = tuple ([
825- _create_embedding_variable_for_jit (
835+ ),
836+ slot = (),
837+ )
838+ embedding_variables ["table_b" ] = embedding .EmbeddingVariables (
839+ table = _create_embedding_variable_for_jit (
826840 [VariableInfo ((64 , 16 ), 0 )], devices , mesh
827- )
828- ])
841+ ),
842+ slot = (),
843+ )
829844 sharded_matmul = functools .partial (
830845 tpu_sparse_dense_matmul_fn ,
831846 feature_specs = tuple (tree .flatten (feature_specs )),
@@ -1327,8 +1342,8 @@ def test_sparse_dense_matmul_four_chips_complex_stacked(
13271342 )
13281343 if using_pmap :
13291344 embedding_variables ["country_table_language_table_related_item_table" ] = (
1330- tuple ([
1331- _create_embedding_variable_for_pmap (
1345+ embedding . EmbeddingVariables (
1346+ table = _create_embedding_variable_for_pmap (
13321347 [
13331348 VariableInfo (shape = (256 , 16 ), offset = 0 ), # country
13341349 VariableInfo (shape = (384 , 16 ), offset = 500 ), # language
@@ -1338,8 +1353,9 @@ def test_sparse_dense_matmul_four_chips_complex_stacked(
13381353 ],
13391354 devices ,
13401355 mesh ,
1341- )
1342- ])
1356+ ),
1357+ slot = (),
1358+ )
13431359 )
13441360 activations = jax .pmap (
13451361 tpu_sparse_dense_matmul_fn ,
@@ -1351,8 +1367,8 @@ def test_sparse_dense_matmul_four_chips_complex_stacked(
13511367 )
13521368 else :
13531369 embedding_variables ["country_table_language_table_related_item_table" ] = (
1354- tuple ([
1355- _create_embedding_variable_for_jit (
1370+ embedding . EmbeddingVariables (
1371+ table = _create_embedding_variable_for_jit (
13561372 [
13571373 VariableInfo (shape = (256 , 16 ), offset = 0 ), # country
13581374 VariableInfo (shape = (384 , 16 ), offset = 500 ), # language
@@ -1362,8 +1378,9 @@ def test_sparse_dense_matmul_four_chips_complex_stacked(
13621378 ],
13631379 devices ,
13641380 mesh ,
1365- )
1366- ])
1381+ ),
1382+ slot = (),
1383+ )
13671384 )
13681385 sharded_matmul = functools .partial (
13691386 tpu_sparse_dense_matmul_fn ,
@@ -1458,11 +1475,12 @@ def test_sparse_dense_matmul_quantized(self):
14581475 )
14591476
14601477 embedding_variables = {}
1461- embedding_variables ["quantized_table" ] = tuple ([
1462- _create_embedding_variable_for_jit (
1478+ embedding_variables ["quantized_table" ] = embedding . EmbeddingVariables (
1479+ table = _create_embedding_variable_for_jit (
14631480 [VariableInfo ((32 , 32 ), 0 )], devices , mesh
1464- )
1465- ])
1481+ ),
1482+ slot = (),
1483+ )
14661484
14671485 tpu_sparse_dense_matmul_fn = functools .partial (
14681486 embedding .tpu_sparse_dense_matmul ,
0 commit comments