Skip to content

Commit 58695a0

Browse files
Internal.
PiperOrigin-RevId: 813766282
1 parent e7ef80a commit 58695a0

File tree

4 files changed

+105
-81
lines changed

4 files changed

+105
-81
lines changed

jax_tpu_embedding/sparsecore/lib/nn/embedding.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -849,7 +849,7 @@ def tpu_sparse_dense_matmul(
849849
sample_id,
850850
gain,
851851
num_minibatches,
852-
embedding_variable[0], # [0] is the embedding table
852+
embedding_variable.table,
853853
device_batch_size=stacked_table.total_sample_count
854854
// global_device_count,
855855
max_ids_per_partition=stacked_table.max_ids_per_partition,

jax_tpu_embedding/sparsecore/lib/nn/tests/tpu_sparse_dense_matmul_test.py

Lines changed: 88 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -184,13 +184,14 @@ def test_static_buffer_size_was_too_small(self):
184184
),
185185
]
186186
sharding = NamedSharding(mesh, P(None, "x", None))
187-
embedding_variables["table"] = tuple([
188-
jax.make_array_from_single_device_arrays(
187+
embedding_variables["table"] = embedding.EmbeddingVariables(
188+
table=jax.make_array_from_single_device_arrays(
189189
shape=(1000, 8),
190190
sharding=sharding,
191191
arrays=embedding_variables["table"],
192-
)
193-
])
192+
),
193+
slot=(),
194+
)
194195
tpu_sparse_dense_matmul = functools.partial(
195196
embedding.tpu_sparse_dense_matmul,
196197
global_device_count=1,
@@ -408,16 +409,18 @@ def test_sparse_dense_matmul_two_chips_sharded(
408409
global_device_count=mesh.size,
409410
)
410411
if using_pmap:
411-
embedding_variables["table_a"] = tuple([
412+
embedding_variables["table_a"] = embedding.EmbeddingVariables(
412413
_create_embedding_variable_for_pmap(
413414
[VariableInfo((32, 8), 0)], devices, mesh
414-
)
415-
])
416-
embedding_variables["table_aa"] = tuple([
415+
),
416+
(),
417+
)
418+
embedding_variables["table_aa"] = embedding.EmbeddingVariables(
417419
_create_embedding_variable_for_pmap(
418420
[VariableInfo((32, 8), 0)], devices, mesh
419-
)
420-
])
421+
),
422+
(),
423+
)
421424
activations = jax.pmap(
422425
tpu_sparse_dense_matmul_fn,
423426
static_broadcasted_argnums=[2],
@@ -427,16 +430,18 @@ def test_sparse_dense_matmul_two_chips_sharded(
427430
tuple(tree.flatten(feature_specs)),
428431
)
429432
else:
430-
embedding_variables["table_a"] = tuple([
431-
_create_embedding_variable_for_jit(
433+
embedding_variables["table_a"] = embedding.EmbeddingVariables(
434+
table=_create_embedding_variable_for_jit(
432435
[VariableInfo((32, 8), 0)], devices, mesh
433-
)
434-
])
435-
embedding_variables["table_aa"] = tuple([
436-
_create_embedding_variable_for_jit(
436+
),
437+
slot=(),
438+
)
439+
embedding_variables["table_aa"] = embedding.EmbeddingVariables(
440+
table=_create_embedding_variable_for_jit(
437441
[VariableInfo((32, 8), 0)], devices, mesh
438-
)
439-
])
442+
),
443+
slot=(),
444+
)
440445
sharded_matmul = functools.partial(
441446
tpu_sparse_dense_matmul_fn,
442447
feature_specs=tuple(tree.flatten(feature_specs)),
@@ -534,16 +539,17 @@ def test_sparse_dense_matmul_two_chips_sharded_stacked(
534539
global_device_count=mesh.size,
535540
)
536541
if using_pmap:
537-
embedding_variables["table_a_table_aa"] = tuple([
538-
_create_embedding_variable_for_pmap(
542+
embedding_variables["table_a_table_aa"] = embedding.EmbeddingVariables(
543+
table=_create_embedding_variable_for_pmap(
539544
[
540545
VariableInfo(shape=(64, 8), offset=0),
541546
VariableInfo(shape=(64, 8), offset=100),
542547
],
543548
devices,
544549
mesh,
545-
)
546-
])
550+
),
551+
slot=(),
552+
)
547553
activations = jax.pmap(
548554
tpu_sparse_dense_matmul_fn,
549555
static_broadcasted_argnums=[2],
@@ -553,16 +559,17 @@ def test_sparse_dense_matmul_two_chips_sharded_stacked(
553559
tuple(tree.flatten(feature_specs)),
554560
)
555561
else:
556-
embedding_variables["table_a_table_aa"] = tuple([
557-
_create_embedding_variable_for_jit(
562+
embedding_variables["table_a_table_aa"] = embedding.EmbeddingVariables(
563+
table=_create_embedding_variable_for_jit(
558564
[
559565
VariableInfo(shape=(64, 8), offset=0),
560566
VariableInfo(shape=(64, 8), offset=100),
561567
],
562568
devices,
563569
mesh,
564-
)
565-
])
570+
),
571+
slot=(),
572+
)
566573
sharded_matmul = functools.partial(
567574
tpu_sparse_dense_matmul_fn,
568575
feature_specs=tuple(tree.flatten(feature_specs)),
@@ -687,16 +694,18 @@ def test_sparse_dense_matmul_single_chip(
687694
global_device_count=mesh.size,
688695
)
689696
if using_pmap:
690-
embedding_variables["table_a"] = tuple([
691-
_create_embedding_variable_for_pmap(
697+
embedding_variables["table_a"] = embedding.EmbeddingVariables(
698+
table=_create_embedding_variable_for_pmap(
692699
[VariableInfo(shape=(32, 8), offset=0)], devices, mesh
693-
)
694-
])
695-
embedding_variables["table_b"] = tuple([
696-
_create_embedding_variable_for_pmap(
700+
),
701+
slot=(),
702+
)
703+
embedding_variables["table_b"] = embedding.EmbeddingVariables(
704+
table=_create_embedding_variable_for_pmap(
697705
[VariableInfo((64, 16), 0)], devices, mesh
698-
)
699-
])
706+
),
707+
slot=(),
708+
)
700709
activations = jax.pmap(
701710
tpu_sparse_dense_matmul_fn,
702711
static_broadcasted_argnums=[2],
@@ -706,16 +715,18 @@ def test_sparse_dense_matmul_single_chip(
706715
tuple(tree.flatten(feature_specs)),
707716
)
708717
else:
709-
embedding_variables["table_a"] = tuple([
710-
_create_embedding_variable_for_jit(
718+
embedding_variables["table_a"] = embedding.EmbeddingVariables(
719+
table=_create_embedding_variable_for_jit(
711720
[VariableInfo((32, 8), 0)], devices, mesh
712-
)
713-
])
714-
embedding_variables["table_b"] = tuple([
715-
_create_embedding_variable_for_jit(
721+
),
722+
slot=(),
723+
)
724+
embedding_variables["table_b"] = embedding.EmbeddingVariables(
725+
table=_create_embedding_variable_for_jit(
716726
[VariableInfo((64, 16), 0)], devices, mesh
717-
)
718-
])
727+
),
728+
slot=(),
729+
)
719730
sparse_matmul = jax.jit(tpu_sparse_dense_matmul_fn, static_argnums=[2])
720731
activations = sparse_matmul(
721732
preprocessed_inputs,
@@ -797,16 +808,18 @@ def test_sparse_dense_matmul_two_tables(
797808
global_device_count=mesh.size,
798809
)
799810
if using_pmap:
800-
embedding_variables["table_a"] = tuple([
801-
_create_embedding_variable_for_pmap(
811+
embedding_variables["table_a"] = embedding.EmbeddingVariables(
812+
table=_create_embedding_variable_for_pmap(
802813
[VariableInfo((32, 8), 0)], devices, mesh
803-
)
804-
])
805-
embedding_variables["table_b"] = tuple([
806-
_create_embedding_variable_for_pmap(
814+
),
815+
slot=(),
816+
)
817+
embedding_variables["table_b"] = embedding.EmbeddingVariables(
818+
table=_create_embedding_variable_for_pmap(
807819
[VariableInfo((64, 16), 0)], devices, mesh
808-
)
809-
])
820+
),
821+
slot=(),
822+
)
810823
activations = jax.pmap(
811824
tpu_sparse_dense_matmul_fn,
812825
static_broadcasted_argnums=(2),
@@ -816,16 +829,18 @@ def test_sparse_dense_matmul_two_tables(
816829
tuple(tree.flatten(feature_specs)),
817830
)
818831
else:
819-
embedding_variables["table_a"] = tuple([
820-
_create_embedding_variable_for_jit(
832+
embedding_variables["table_a"] = embedding.EmbeddingVariables(
833+
table=_create_embedding_variable_for_jit(
821834
[VariableInfo((32, 8), 0)], devices, mesh
822-
)
823-
])
824-
embedding_variables["table_b"] = tuple([
825-
_create_embedding_variable_for_jit(
835+
),
836+
slot=(),
837+
)
838+
embedding_variables["table_b"] = embedding.EmbeddingVariables(
839+
table=_create_embedding_variable_for_jit(
826840
[VariableInfo((64, 16), 0)], devices, mesh
827-
)
828-
])
841+
),
842+
slot=(),
843+
)
829844
sharded_matmul = functools.partial(
830845
tpu_sparse_dense_matmul_fn,
831846
feature_specs=tuple(tree.flatten(feature_specs)),
@@ -1327,8 +1342,8 @@ def test_sparse_dense_matmul_four_chips_complex_stacked(
13271342
)
13281343
if using_pmap:
13291344
embedding_variables["country_table_language_table_related_item_table"] = (
1330-
tuple([
1331-
_create_embedding_variable_for_pmap(
1345+
embedding.EmbeddingVariables(
1346+
table=_create_embedding_variable_for_pmap(
13321347
[
13331348
VariableInfo(shape=(256, 16), offset=0), # country
13341349
VariableInfo(shape=(384, 16), offset=500), # language
@@ -1338,8 +1353,9 @@ def test_sparse_dense_matmul_four_chips_complex_stacked(
13381353
],
13391354
devices,
13401355
mesh,
1341-
)
1342-
])
1356+
),
1357+
slot=(),
1358+
)
13431359
)
13441360
activations = jax.pmap(
13451361
tpu_sparse_dense_matmul_fn,
@@ -1351,8 +1367,8 @@ def test_sparse_dense_matmul_four_chips_complex_stacked(
13511367
)
13521368
else:
13531369
embedding_variables["country_table_language_table_related_item_table"] = (
1354-
tuple([
1355-
_create_embedding_variable_for_jit(
1370+
embedding.EmbeddingVariables(
1371+
table=_create_embedding_variable_for_jit(
13561372
[
13571373
VariableInfo(shape=(256, 16), offset=0), # country
13581374
VariableInfo(shape=(384, 16), offset=500), # language
@@ -1362,8 +1378,9 @@ def test_sparse_dense_matmul_four_chips_complex_stacked(
13621378
],
13631379
devices,
13641380
mesh,
1365-
)
1366-
])
1381+
),
1382+
slot=(),
1383+
)
13671384
)
13681385
sharded_matmul = functools.partial(
13691386
tpu_sparse_dense_matmul_fn,
@@ -1458,11 +1475,12 @@ def test_sparse_dense_matmul_quantized(self):
14581475
)
14591476

14601477
embedding_variables = {}
1461-
embedding_variables["quantized_table"] = tuple([
1462-
_create_embedding_variable_for_jit(
1478+
embedding_variables["quantized_table"] = embedding.EmbeddingVariables(
1479+
table=_create_embedding_variable_for_jit(
14631480
[VariableInfo((32, 32), 0)], devices, mesh
1464-
)
1465-
])
1481+
),
1482+
slot=(),
1483+
)
14661484

14671485
tpu_sparse_dense_matmul_fn = functools.partial(
14681486
embedding.tpu_sparse_dense_matmul,

jax_tpu_embedding/sparsecore/tests/jax_sc_shakespeare_tests.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -128,13 +128,14 @@ def test_shakespeare_model_loss_convergence(self):
128128
for i, device in enumerate(devices)
129129
]
130130
sharding = NamedSharding(mesh, P('x', None))
131-
embedding_variables[model.table_name] = tuple([
132-
jax.make_array_from_single_device_arrays(
131+
embedding_variables[model.table_name] = embedding.EmbeddingVariables(
132+
table=jax.make_array_from_single_device_arrays(
133133
shape=(_VOCAB_SIZE.value, _EMBEDDING_SIZE.value),
134134
sharding=sharding,
135135
arrays=embedding_variables[model.table_name],
136-
)
137-
])
136+
),
137+
slot=(),
138+
)
138139

139140
# Define the forward pass function.
140141
loss_grad_fn = jax.value_and_grad(

jax_tpu_embedding/sparsecore/tests/jax_spmd_tc_with_sc_tests.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -163,13 +163,18 @@ def setUp(self):
163163
for i, device in enumerate(self.devices)
164164
]
165165
sharding = NamedSharding(self.mesh, P('device', None))
166-
self.embedding_variables[self.shakespeare_table_spec.name] = tuple([
167-
jax.make_array_from_single_device_arrays(
168-
shape=(_VOCAB_SIZE.value, _EMBEDDING_SIZE.value),
169-
sharding=sharding,
170-
arrays=self.embedding_variables[self.shakespeare_table_spec.name],
166+
self.embedding_variables[self.shakespeare_table_spec.name] = (
167+
embedding.EmbeddingVariables(
168+
table=jax.make_array_from_single_device_arrays(
169+
shape=(_VOCAB_SIZE.value, _EMBEDDING_SIZE.value),
170+
sharding=sharding,
171+
arrays=self.embedding_variables[
172+
self.shakespeare_table_spec.name
173+
],
174+
),
175+
slot=(),
171176
)
172-
])
177+
)
173178
# Construct the model.
174179
self.model = ShakespeareSpmdModel(
175180
vocab_size=_VOCAB_SIZE.value,

0 commit comments

Comments (0)