
Commit 376edd4

avoid oom
1 parent d8bbfc9 commit 376edd4

2 files changed: 20 additions, 12 deletions

lmdeploy/pytorch/model_inputs.py

Lines changed: 5 additions & 2 deletions

@@ -28,8 +28,11 @@ def _gather_tp_sizes(tp: int, seqlen: int, dist_ctx: dist.DistContext, layer_typ
     if tp > 1 and tp != attn_tp:
         dist_group = dist.get_dist_group(layer_type=layer_type)
         gather_group = dist_group.gpu_gather_group
-        tp_sizes = [None for _ in range(gather_group.size())]
-        dist.all_gather_object(tp_sizes, seqlen, group=gather_group)
+        rank = gather_group.rank()
+        tp_size_tensor = torch.zeros(gather_group.size(), dtype=torch.int32, device='cuda')
+        tp_size_tensor[rank].fill_(seqlen)
+        dist.all_gather_into_tensor(tp_size_tensor, tp_size_tensor[rank], group=gather_group)
+        tp_sizes = tp_size_tensor.tolist()
     else:
         tp_sizes = [seqlen]
     return tp_sizes
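The OOM fix here is the move off all_gather_object: that call pickles each rank's Python object and can stage the bytes through freshly allocated device buffers on every invocation, whereas a fixed-size int32 tensor gathered with all_gather_into_tensor reuses one small buffer. A minimal standalone sketch of the same pattern, assuming an initialized NCCL process group with one CUDA device per rank (gather_seqlens is our illustrative name, not from the commit):

import torch
import torch.distributed as dist

def gather_seqlens(seqlen: int, group: dist.ProcessGroup) -> list:
    """Gather one int per rank without all_gather_object.

    A single preallocated int32 tensor holds one slot per rank; each rank
    fills its own slot and the collective gathers in place, so no per-call
    pickling or temporary device buffers are needed.
    """
    world_size = dist.get_world_size(group)
    rank = dist.get_rank(group)
    sizes = torch.zeros(world_size, dtype=torch.int32, device='cuda')
    sizes[rank] = seqlen
    # In-place gather: each rank's input is its own slot of the output.
    dist.all_gather_into_tensor(sizes, sizes[rank:rank + 1], group=group)
    return sizes.tolist()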

lmdeploy/pytorch/nn/moe.py

Lines changed: 15 additions & 10 deletions
@@ -84,10 +84,10 @@ def __init__(self, gemm_func: Callable, max_tokens_per_round: int = 4096):
     def all_gather(self, hidden_states: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor,
                    tp_sizes: List[int]):
         """All gather."""
-        hidden_states, _ = dist.gather_by_tp_sizes(hidden_states, tp_sizes, group=self.gather_group, async_op=True)
-        topk_weights, _ = dist.gather_by_tp_sizes(topk_weights, tp_sizes, group=self.gather_group, async_op=True)
-        topk_ids, handle = dist.gather_by_tp_sizes(topk_ids, tp_sizes, group=self.gather_group, async_op=True)
-        return hidden_states, topk_weights, topk_ids, handle
+        hidden_states, h0 = dist.gather_by_tp_sizes(hidden_states, tp_sizes, group=self.gather_group, async_op=True)
+        topk_weights, h1 = dist.gather_by_tp_sizes(topk_weights, tp_sizes, group=self.gather_group, async_op=True)
+        topk_ids, h2 = dist.gather_by_tp_sizes(topk_ids, tp_sizes, group=self.gather_group, async_op=True)
+        return hidden_states, topk_weights, topk_ids, (h0, h1, h2)

     def reduce_scatter(self, hidden_states: torch.Tensor, out_states: torch.Tensor, tp_sizes: List[int]):
         """Reduce scatter."""
@@ -100,9 +100,10 @@ def reduce_scatter(self, hidden_states: torch.Tensor, out_states: torch.Tensor,
         return out_states, handle

     def _gemm_and_reduce_scatter(self, hidden_states: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor,
-                                 output_states: torch.Tensor, tp_sizes: List[int], handle: dist.Work):
+                                 output_states: torch.Tensor, tp_sizes: List[int], handles: List[dist.Work]):
         """Gemm and reduce scatter."""
-        handle.wait()
+        for handle in handles:
+            handle.wait()
         cur_out = self.gemm_func(hidden_states, topk_weights, topk_ids)
         return self.reduce_scatter(cur_out, output_states, tp_sizes)

@@ -129,13 +130,13 @@ def __slice_and_gather():
             cur_output, output_states = __slice_tensor(output_states, slice_size)

             # all gather
-            cur_hidden_states, cur_topk_weights, cur_topk_ids, handle = self.all_gather(
+            cur_hidden_states, cur_topk_weights, cur_topk_ids, handles = self.all_gather(
                 cur_hidden_states, cur_topk_weights, cur_topk_ids, cur_tp_sizes)
             return dict(hidden_states=cur_hidden_states,
                         topk_weights=cur_topk_weights,
                         topk_ids=cur_topk_ids,
                         output_states=cur_output,
-                        handle=handle,
+                        handles=handles,
                         tp_sizes=cur_tp_sizes)

         step_ctx = get_step_ctx_manager().current_context()
@@ -149,15 +150,19 @@ def __slice_and_gather():
         # pre
         cur_inputs = __slice_and_gather()

+        out_handles = []
         # main loop
         while tp_sizes.sum() > 0:
            next_inputs = __slice_and_gather()
-           self._gemm_and_reduce_scatter(**cur_inputs)
+           _, handle = self._gemm_and_reduce_scatter(**cur_inputs)
+           out_handles.append(handle)
            cur_inputs = next_inputs

        # post
        _, handle = self._gemm_and_reduce_scatter(**cur_inputs)
-       handle.wait()
+       out_handles.append(handle)
+       for handle in out_handles:
+           handle.wait()
        return return_states

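The main-loop change defers every reduce-scatter wait to a single drain at the end, so each round's gather and GEMM can overlap the previous round's communication instead of blocking on it. A simplified sketch of the same structure, assuming an initialized process group (run_chunked and compute are illustrative stand-ins, not lmdeploy APIs):

import torch
import torch.distributed as dist

def run_chunked(chunks, compute, group):
    """Per-chunk compute with async reduce_scatter, drained at the end.

    Mirrors the hunk above: each round launches the communication for the
    finished chunk without blocking, the handles accumulate in out_handles,
    and one drain loop waits for all of them before results are used.
    """
    world_size = dist.get_world_size(group)
    out_handles, outputs = [], []
    for chunk in chunks:
        full = compute(chunk)  # stands in for the grouped GEMM
        out = full.new_empty((full.shape[0] // world_size,) + tuple(full.shape[1:]))
        handle = dist.reduce_scatter_tensor(out, full, group=group, async_op=True)
        outputs.append(out)
        out_handles.append(handle)
    for handle in out_handles:
        handle.wait()
    return outputs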
