Skip to content

Commit 4cb39c1

Browse files
faran928meta-codesync[bot]
authored andcommitted
Support for uneven heterogenous sharding for inference sharded tensor pool (#3533)
Summary: Pull Request resolved: #3533 A few changes in the diff: 1. Support to proportionally shard the tensor pool based on memory capacity per rank. 2. Using block_bucketize_sparse_features_inference to return bucket_mapping that can be used during request batching in inference w/ custom sigrid predictor engine 3. Wrapping some of the operations with fx wrappers to make it compatible with model split boundaries for DLRM serving where embeddings are sharded and split onto different pytorch modules 4. Exposing set_device() api to some of the modules if we want to place some shards to cpu while others to cuda. 5. Move _get_unbucketize_tensor_via_length_alignment to common util files. As part of this diff, also had to update some of the test cases. Mainly because updating the forward path a bit leads to reorganization of return values of remote module in the test cases leading to reorganization of batchinfo for each of those output. Baseline test Full Output: https://www.internalfb.com/intern/everpaste/?handle=GMpN0B-VwcKcvAgDAB_hIkOvKrJNbsIXAAAB&phabricator_paste_number=2035160680 Remote graph: https://www.internalfb.com/intern/everpaste/?color=0&handle=GDe4BCBV-XwTNLoEAOTnVbCneb1jbr0LAAAz Output order: (_item_embedding_index_values_tensor_pool__local_shard_pools_0, _item_embedding_index_values_tensor_pool__local_shard_pools_1, getitem_6, getitem_10, getitem_9) After changes Full Output: https://www.internalfb.com/intern/everpaste/?handle=GIb_pSL6TjHawBgEALDgarUck4YhbsIXAAAB&phabricator_paste_number=2035191658 Remote graph: https://www.internalfb.com/intern/everpaste/?color=0&handle=GFmFkB9waX3elzMGAJr9zTcLZiIDbr0LAAAz Output Order: getitem_6, _item_embedding_index_values_tensor_pool__local_shard_pools_0, _item_embedding_index_values_tensor_pool__local_shard_pools_1, getitem_10, getitem_9 getitem_6 is shifted first after changes. Reviewed By: jiayisuse Differential Revision: D79603009 fbshipit-source-id: 2da7e7e40c8569ba543360e022098debd278fb14
1 parent 115aaa8 commit 4cb39c1

File tree

9 files changed

+529
-37
lines changed

9 files changed

+529
-37
lines changed

torchrec/distributed/keyed_jagged_tensor_pool.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -608,7 +608,7 @@ def create_context(self) -> ObjectPoolShardingContext:
608608
def _lookup_ids_dist(
609609
self,
610610
ids: torch.Tensor,
611-
) -> Tuple[List[torch.Tensor], torch.Tensor]:
611+
) -> Tuple[List[torch.Tensor], torch.Tensor, torch.Tensor, torch.Tensor]:
612612
return self._lookup_ids_dist_impl(ids)
613613

614614
# pyre-ignore
@@ -630,7 +630,7 @@ def _lookup_values_dist(
630630

631631
# pyre-ignore
632632
def forward(self, ids: torch.Tensor) -> KeyedJaggedTensor:
633-
dist_input, unbucketize_permute = self._lookup_ids_dist(ids)
633+
dist_input, unbucketize_permute, _, _ = self._lookup_ids_dist(ids)
634634
lookup = self._lookup_local(dist_input)
635635
# Here we are playing a trick to workaround a fx tracing issue,
636636
# as proxy is not iteratable.

torchrec/distributed/quant_embedding.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@
8484
from torchrec.modules.utils import (
8585
_fx_trec_get_feature_length,
8686
_get_batching_hinted_output,
87+
_get_unbucketize_tensor_via_length_alignment,
8788
)
8889
from torchrec.quant.embedding_modules import (
8990
EmbeddingCollection as QuantEmbeddingCollection,
@@ -96,6 +97,7 @@
9697
torch.fx.wrap("len")
9798
torch.fx.wrap("_get_batching_hinted_output")
9899
torch.fx.wrap("_fx_trec_get_feature_length")
100+
torch.fx.wrap("_get_unbucketize_tensor_via_length_alignment")
99101

100102
try:
101103
torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")
@@ -278,16 +280,6 @@ def _fx_trec_wrap_length_tolist(length: torch.Tensor) -> List[int]:
278280
return length.long().tolist()
279281

280282

281-
@torch.fx.wrap
282-
def _get_unbucketize_tensor_via_length_alignment(
283-
lengths: torch.Tensor,
284-
bucketize_length: torch.Tensor,
285-
bucketize_permute_tensor: torch.Tensor,
286-
bucket_mapping_tensor: torch.Tensor,
287-
) -> torch.Tensor:
288-
return bucketize_permute_tensor
289-
290-
291283
@torch.fx.wrap
292284
def _fx_split_embeddings_per_feature_length(
293285
embeddings: torch.Tensor,

torchrec/distributed/sharding/rw_pool_sharding.py

Lines changed: 39 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,9 @@ class InferRwObjectPoolInputDist(torch.nn.Module):
166166
block_size (torch.Tensor): tensor containing block sizes for each rank.
167167
e.g. if block_size=torch.tensor(100), then IDs 0-99 will be assigned to rank
168168
0, 100-199 to rank 1, and so on.
169+
block_bucketize_row_pos (torch.Tensor]): tensor containing shard/row offsets for each
170+
rank in case of uneven sharding of the tensor pool across ranks. If not provided,
171+
then block_size will be used to permute the IDs across ranks.
169172
170173
Example:
171174
device = torch.device("cpu")
@@ -179,22 +182,27 @@ class InferRwObjectPoolInputDist(torch.nn.Module):
179182
_world_size: int
180183
_device: torch.device
181184
_block_size: torch.Tensor
185+
_block_bucketize_row_pos: list[torch.Tensor]
182186

183187
def __init__(
184188
self,
185189
env: ShardingEnv,
186190
device: torch.device,
187191
block_size: torch.Tensor,
192+
block_bucketize_row_pos: Optional[list[torch.Tensor]] = None,
188193
) -> None:
189194
super().__init__()
190195
self._world_size = env.world_size
191196
self._device = device
192197
self._block_size = block_size
198+
self._block_bucketize_row_pos: list[torch.Tensor] = (
199+
[] if block_bucketize_row_pos is None else block_bucketize_row_pos
200+
)
193201

194202
def forward(
195203
self,
196204
ids: torch.Tensor,
197-
) -> Tuple[List[torch.Tensor], torch.Tensor]:
205+
) -> Tuple[List[torch.Tensor], torch.Tensor, torch.Tensor, torch.Tensor]:
198206
"""
199207
Bucketizes ids tensor into a list of tensors each containing ids
200208
for the corresponding rank. Places each tensor on the appropriate device.
@@ -203,24 +211,34 @@ def forward(
203211
ids (torch.Tensor): Tensor with ids
204212
205213
Returns:
206-
Tuple[List[torch.Tensor], torch.Tensor]: Tuple containing list of ids tensors
207-
for each rank given the bucket sizes, and the tensor containing indices
208-
to permute the ids to get the original order before bucketization.
214+
Tuple[List[torch.Tensor], torch.Tensor, torch.Tensor, torch.Tensor]:
215+
Tuple containing
216+
1. list of ids tensors for each rank given the bucket sizes
217+
2. the tensor containing indices to permute the ids to get the original order before bucketization.
218+
3. the tensor containing the bucket mapping for each id
219+
4. the tensor containing the bucketized lengths
209220
"""
210221
(
211222
bucketized_lengths,
212223
bucketized_indices,
213-
_bucketized_weights,
214-
_bucketize_permute,
224+
_, # bucketized_weights
225+
_, # _bucketize_permute
215226
unbucketize_permute,
216-
) = torch.ops.fbgemm.block_bucketize_sparse_features(
217-
_get_bucketize_shape(ids, ids.device),
218-
ids.long(),
227+
bucket_mapping,
228+
) = torch.ops.fbgemm.block_bucketize_sparse_features_inference(
229+
lengths=_get_bucketize_shape(ids, ids.device),
230+
indices=ids.long(),
219231
bucketize_pos=False,
220232
sequence=True,
221233
block_sizes=self._block_size.long(),
222234
my_size=self._world_size,
223235
weights=None,
236+
block_bucketize_pos=(
237+
self._block_bucketize_row_pos
238+
if len(self._block_bucketize_row_pos) > 0
239+
else None
240+
),
241+
return_bucket_mapping=True,
224242
)
225243

226244
id_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(bucketized_lengths)
@@ -236,7 +254,13 @@ def forward(
236254
)
237255

238256
assert unbucketize_permute is not None, "unbucketize permute must not be None"
239-
return dist_ids, unbucketize_permute
257+
assert bucket_mapping is not None, "bucket mapping must not be None"
258+
return (
259+
dist_ids,
260+
unbucketize_permute,
261+
bucket_mapping,
262+
bucketized_lengths,
263+
)
240264

241265
def update(
242266
self,
@@ -270,6 +294,11 @@ def update(
270294
block_sizes=self._block_size.long(),
271295
my_size=self._world_size,
272296
weights=None,
297+
block_bucketize_pos=(
298+
self._block_bucketize_row_pos
299+
if len(self._block_bucketize_row_pos) > 0
300+
else None
301+
),
273302
)
274303

275304
id_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(bucketized_lengths)

torchrec/distributed/sharding/rw_tensor_pool_sharding.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,8 @@ class InferRwTensorPoolOutputDist(torch.nn.Module):
213213
vals = torch.Tensor([1,2,3,4,5,6], device=device)
214214
"""
215215

216+
__annotations__ = {"_device": Optional[torch.device]}
217+
216218
def __init__(
217219
self,
218220
env: ShardingEnv,
@@ -224,6 +226,11 @@ def __init__(
224226
self._cat_dim = 0
225227
self._placeholder: torch.Tensor = torch.ones(1, device=device)
226228

229+
@torch.jit.export
230+
def set_device(self, device_str: str) -> None:
231+
self._device = torch.device(device_str)
232+
self._placeholder = torch.ones(1, device=self._device)
233+
227234
def forward(
228235
self,
229236
lookups: List[torch.Tensor],
@@ -256,12 +263,16 @@ def __init__(
256263
pool_size: int,
257264
env: ShardingEnv,
258265
device: torch.device,
266+
memory_capacity_per_rank: Optional[list[int]] = None,
259267
) -> None:
260-
super().__init__(pool_size, env, device)
268+
super().__init__(pool_size, env, device, memory_capacity_per_rank)
261269

262270
def create_lookup_ids_dist(self) -> InferRwObjectPoolInputDist:
263271
return InferRwObjectPoolInputDist(
264-
self._env, device=self._device, block_size=self._block_size_t
272+
self._env,
273+
device=self._device,
274+
block_size=self._block_size_t,
275+
block_bucketize_row_pos=self._block_bucketize_row_pos,
265276
)
266277

267278
def create_lookup_values_dist(

torchrec/distributed/tensor_pool.py

Lines changed: 75 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,14 @@
3232
)
3333
from torchrec.modules.object_pool_lookups import TensorLookup, TensorPoolLookup
3434
from torchrec.modules.tensor_pool import TensorPool
35-
from torchrec.modules.utils import deterministic_dedup
35+
from torchrec.modules.utils import (
36+
_get_batching_hinted_output,
37+
_get_unbucketize_tensor_via_length_alignment,
38+
deterministic_dedup,
39+
)
40+
41+
torch.fx.wrap("_get_unbucketize_tensor_via_length_alignment")
42+
torch.fx.wrap("_get_batching_hinted_output")
3643

3744

3845
@torch.fx.wrap
@@ -44,6 +51,17 @@ def index_select_view(
4451
return output[unbucketize_permute].view(-1, dim)
4552

4653

54+
@torch.fx.wrap
55+
def _fx_item_unwrap_optional_tensor(optional: Optional[torch.Tensor]) -> torch.Tensor:
56+
assert optional is not None, "Expected optional to be non-None Tensor"
57+
return optional
58+
59+
60+
@torch.fx.wrap
61+
def _get_id_length_sharded_tensor_pool(ids: torch.Tensor) -> torch.Tensor:
62+
return torch.tensor([ids.size(dim=0)], device=ids.device, dtype=torch.long)
63+
64+
4765
class TensorPoolAwaitable(LazyAwaitable[torch.Tensor]):
4866
def __init__(
4967
self,
@@ -271,6 +289,8 @@ class LocalShardPool(torch.nn.Module):
271289
# out is tensor([1,2,3]) i.e. first row of the shard
272290
"""
273291

292+
current_device: torch.device
293+
274294
def __init__(
275295
self,
276296
shard: torch.Tensor,
@@ -280,6 +300,12 @@ def __init__(
280300
shard,
281301
requires_grad=False,
282302
)
303+
self.current_device = self._shard.device
304+
305+
@torch.jit.export
306+
def set_device(self, device_str: str) -> None:
307+
self.current_device = torch.device(device_str)
308+
self._shard.to(self.current_device)
283309

284310
def forward(self, rank_ids: torch.Tensor) -> torch.Tensor:
285311
"""
@@ -291,7 +317,7 @@ def forward(self, rank_ids: torch.Tensor) -> torch.Tensor:
291317
Returns:
292318
torch.Tensor: Tensor of values corresponding to the given rank ids.
293319
"""
294-
return self._shard[rank_ids]
320+
return self._shard[rank_ids.to(self.current_device)]
295321

296322
def update(self, rank_ids: torch.Tensor, values: torch.Tensor) -> None:
297323
_ = update(self._shard, rank_ids, values)
@@ -337,6 +363,11 @@ def __init__(
337363
env=self._sharding_env,
338364
device=self._device,
339365
pool_size=self._pool_size,
366+
memory_capacity_per_rank=(
367+
self._sharding_plan.memory_capacity_per_rank
368+
if self._sharding_plan.memory_capacity_per_rank is not None
369+
else None
370+
),
340371
)
341372
else:
342373
raise NotImplementedError(
@@ -356,6 +387,7 @@ def __init__(
356387
if device == torch.device("cpu")
357388
else torch.device("cuda", rank)
358389
)
390+
359391
self._local_shard_pools.append(
360392
LocalShardPool(
361393
torch.empty(
@@ -409,7 +441,7 @@ def create_context(self) -> ObjectPoolShardingContext:
409441
def _lookup_ids_dist(
410442
self,
411443
ids: torch.Tensor,
412-
) -> Tuple[List[torch.Tensor], torch.Tensor]:
444+
) -> Tuple[List[torch.Tensor], torch.Tensor, torch.Tensor, torch.Tensor]:
413445
return self._lookup_ids_dist_impl(ids)
414446

415447
# pyre-ignore
@@ -439,18 +471,54 @@ def _lookup_values_dist(
439471

440472
# pyre-ignore
441473
def forward(self, ids: torch.Tensor) -> torch.Tensor:
442-
dist_input, unbucketize_permute = self._lookup_ids_dist(ids)
474+
dist_input, unbucketize_permute, bucket_mapping, bucketized_lengths = (
475+
self._lookup_ids_dist(ids)
476+
)
477+
unbucketize_permute_non_opt = _fx_item_unwrap_optional_tensor(
478+
unbucketize_permute
479+
)
480+
443481
lookup = self._lookup_local(dist_input)
444482

445483
# Here we are playing a trick to workaround a fx tracing issue,
446484
# as proxy is not iteratable.
447485
lookup_list = []
448-
for i in range(self._world_size):
449-
lookup_list.append(lookup[i])
486+
# In case of non-heterogenous even sharding keeping the behavior
487+
# consistent with existing logic to ensure that additional fx wrappers
488+
# do not impact the model split logic during inference in anyway
489+
if self._sharding_plan.memory_capacity_per_rank is None:
490+
for i in range(self._world_size):
491+
lookup_list.append(lookup[i])
492+
else:
493+
# Adding fx wrappers in case of uneven heterogenous sharding to
494+
# make it compatible with model split boundaries during inference
495+
for i in range(self._world_size):
496+
lookup_list.append(
497+
_get_batching_hinted_output(
498+
_get_id_length_sharded_tensor_pool(dist_input[i]), lookup[i]
499+
)
500+
)
501+
502+
features_before_input_dist_length = _get_id_length_sharded_tensor_pool(ids)
503+
bucketized_lengths_col_view = bucketized_lengths.view(self._world_size, -1)
504+
unbucketize_permute_non_opt = _fx_item_unwrap_optional_tensor(
505+
unbucketize_permute
506+
)
507+
bucket_mapping_non_opt = _fx_item_unwrap_optional_tensor(bucket_mapping)
508+
unbucketize_permute_non_opt = _get_unbucketize_tensor_via_length_alignment(
509+
features_before_input_dist_length,
510+
bucketized_lengths_col_view,
511+
unbucketize_permute_non_opt,
512+
bucket_mapping_non_opt,
513+
)
450514

451515
output = self._lookup_values_dist(lookup_list)
452516

453-
return index_select_view(output, unbucketize_permute, self._dim)
517+
return index_select_view(
518+
output,
519+
unbucketize_permute_non_opt.to(device=output.device),
520+
self._dim,
521+
)
454522

455523
# pyre-ignore
456524
def _update_values_dist(self, ctx: ObjectPoolShardingContext, values: torch.Tensor):

0 commit comments

Comments
 (0)