Skip to content

Commit dc28b85

Browse files
chengyuma and mcyky authored
fix: handle non-tuple decoder outputs during Qwen-2.5 quantization (InternLM#4158)
* fix: handle non-tuple decoder outputs during Qwen-2.5 quantization
* fix lint

---------

Co-authored-by: machengyu <[email protected]>
1 parent f363bab commit dc28b85

File tree

1 file changed

+12
-4
lines changed

1 file changed

+12
-4
lines changed

lmdeploy/lite/utils/batch_split.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -58,18 +58,23 @@ def split_decoder_layer_inputs(batch_size, *args: Union[torch.Tensor, Any],
5858
return batch_args, batch_kwargs
5959

6060

61-
def concat_decoder_layer_outputs(batch_outputs: List[Tuple[Any]]) -> Tuple[Any]:
61+
def concat_decoder_layer_outputs(batch_outputs: List[Any]) -> Any:
6262
"""This function concatenates individual decoder layer outputs into a
6363
batched output.
6464
6565
Args:
66-
batch_outputs (List[Tuple[Any]]): A list of tuples, where each tuple
66+
batch_outputs (List[Any]): A list, where each tuple
6767
represents the output from an individual element in the batch.
6868
6969
Returns:
70-
Tuple[Any]: A tuple representing the batched output.
70+
Any: Batched output.
7171
"""
7272

73+
output_is_tuple = True
74+
if not isinstance(batch_outputs[0], tuple):
75+
output_is_tuple = False
76+
batch_outputs = [(output, ) for output in batch_outputs]
77+
7378
num_returns = len(batch_outputs[0])
7479

7580
def is_past_key_value(data: Any) -> bool:
@@ -105,4 +110,7 @@ def is_past_key_value(data: Any) -> bool:
105110
out_i = torch.cat([out[i] for out in batch_outputs])
106111
new_outputs.append(out_i)
107112

108-
return tuple(new_outputs)
113+
if output_is_tuple:
114+
return tuple(new_outputs)
115+
else:
116+
return new_outputs[0]

0 commit comments

Comments
 (0)