@@ -84,10 +84,10 @@ def __init__(self, gemm_func: Callable, max_tokens_per_round: int = 4096):
     def all_gather(self, hidden_states: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor,
                    tp_sizes: List[int]):
         """All gather."""
-        hidden_states, _ = dist.gather_by_tp_sizes(hidden_states, tp_sizes, group=self.gather_group, async_op=True)
-        topk_weights, _ = dist.gather_by_tp_sizes(topk_weights, tp_sizes, group=self.gather_group, async_op=True)
-        topk_ids, handle = dist.gather_by_tp_sizes(topk_ids, tp_sizes, group=self.gather_group, async_op=True)
-        return hidden_states, topk_weights, topk_ids, handle
+        hidden_states, h0 = dist.gather_by_tp_sizes(hidden_states, tp_sizes, group=self.gather_group, async_op=True)
+        topk_weights, h1 = dist.gather_by_tp_sizes(topk_weights, tp_sizes, group=self.gather_group, async_op=True)
+        topk_ids, h2 = dist.gather_by_tp_sizes(topk_ids, tp_sizes, group=self.gather_group, async_op=True)
+        return hidden_states, topk_weights, topk_ids, (h0, h1, h2)
 
     def reduce_scatter(self, hidden_states: torch.Tensor, out_states: torch.Tensor, tp_sizes: List[int]):
         """Reduce scatter."""
@@ -100,9 +100,10 @@ def reduce_scatter(self, hidden_states: torch.Tensor, out_states: torch.Tensor,
         return out_states, handle
 
     def _gemm_and_reduce_scatter(self, hidden_states: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor,
-                                 output_states: torch.Tensor, tp_sizes: List[int], handle: dist.Work):
+                                 output_states: torch.Tensor, tp_sizes: List[int], handles: List[dist.Work]):
         """Gemm and reduce scatter."""
-        handle.wait()
+        for handle in handles:
+            handle.wait()
         cur_out = self.gemm_func(hidden_states, topk_weights, topk_ids)
         return self.reduce_scatter(cur_out, output_states, tp_sizes)
 
@@ -129,13 +130,13 @@ def __slice_and_gather():
             cur_output, output_states = __slice_tensor(output_states, slice_size)
 
             # all gather
-            cur_hidden_states, cur_topk_weights, cur_topk_ids, handle = self.all_gather(
+            cur_hidden_states, cur_topk_weights, cur_topk_ids, handles = self.all_gather(
                 cur_hidden_states, cur_topk_weights, cur_topk_ids, cur_tp_sizes)
             return dict(hidden_states=cur_hidden_states,
                         topk_weights=cur_topk_weights,
                         topk_ids=cur_topk_ids,
                         output_states=cur_output,
-                        handle=handle,
+                        handles=handles,
                         tp_sizes=cur_tp_sizes)
 
         step_ctx = get_step_ctx_manager().current_context()
@@ -149,15 +150,19 @@ def __slice_and_gather():
         # pre
         cur_inputs = __slice_and_gather()
 
+        out_handles = []
         # main loop
         while tp_sizes.sum() > 0:
             next_inputs = __slice_and_gather()
-            self._gemm_and_reduce_scatter(**cur_inputs)
+            _, handle = self._gemm_and_reduce_scatter(**cur_inputs)
+            out_handles.append(handle)
             cur_inputs = next_inputs
 
         # post
         _, handle = self._gemm_and_reduce_scatter(**cur_inputs)
-        handle.wait()
+        out_handles.append(handle)
+        for handle in out_handles:
+            handle.wait()
         return return_states
 
 
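The patch changes two things: `all_gather` now returns all three async `Work` handles instead of discarding the first two and waiting only on the `topk_ids` gather, and the per-round reduce-scatter handles are collected into `out_handles` so the waits happen once after the loop, letting each round's communication overlap the next round's GEMM. A minimal sketch of the multi-handle pattern using plain `torch.distributed` collectives (the `gather_three_async` helper, shapes, and `all_gather_into_tensor` stand in for the `gather_by_tp_sizes` helper used in the patch, which is not shown here):

```python
import torch
import torch.distributed as dist


def gather_three_async(hidden_states: torch.Tensor, topk_weights: torch.Tensor,
                       topk_ids: torch.Tensor, group=None):
    """Launch one async all-gather per tensor; return outputs plus all Work handles."""
    world_size = dist.get_world_size(group)
    outputs, handles = [], []
    for t in (hidden_states, topk_weights, topk_ids):
        out = t.new_empty((world_size * t.shape[0], *t.shape[1:]))
        # async_op=True returns a Work handle instead of blocking the caller
        handles.append(dist.all_gather_into_tensor(out, t, group=group, async_op=True))
        outputs.append(out)
    return (*outputs, tuple(handles))


# The consumer must wait on every handle before reading any gathered tensor,
# mirroring the loop added to _gemm_and_reduce_scatter:
# hidden, weights, ids, handles = gather_three_async(hidden, weights, ids)
# for handle in handles:
#     handle.wait()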