Commit cf7c811

fix: number of return values and response content-type handling (#2129)
1 parent 6dda206 commit cf7c811

File tree

1 file changed (+35, -28 lines)


request_llms/oai_std_model_template.py

Lines changed: 35 additions & 28 deletions
@@ -1,16 +1,13 @@
 import json
 import time
 import traceback
+
 import requests
 from loguru import logger
 
 # Put your secrets, such as API keys and proxy URLs, in config_private.py
 # On load, check first for a private config_private file (not tracked by git); if it exists, it overrides the original config file
-from toolbox import (
-    get_conf,
-    update_ui,
-    is_the_upload_folder,
-)
+from toolbox import get_conf, is_the_upload_folder, update_ui
 
 proxies, TIMEOUT_SECONDS, MAX_RETRY = get_conf(
     "proxies", "TIMEOUT_SECONDS", "MAX_RETRY"
@@ -39,27 +36,35 @@ def decode_chunk(chunk):
     Parses the "content" and "finish_reason" fields (and, if chain-of-thought is supported, "reasoning_content" as well)
     """
     chunk = chunk.decode()
-    respose = ""
+    response = ""
     reasoning_content = ""
     finish_reason = "False"
+
+    # Handle both response content types: text/json and text/event-stream
+    if chunk.startswith("data: "):
+        chunk = chunk[6:]
+    else:
+        chunk = chunk
+
     try:
-        chunk = json.loads(chunk[6:])
+        chunk = json.loads(chunk)
     except:
-        respose = ""
+        response = ""
         finish_reason = chunk
+
     # Error handling
     if "error" in chunk:
-        respose = "API_ERROR"
+        response = "API_ERROR"
         try:
             chunk = json.loads(chunk)
             finish_reason = chunk["error"]["code"]
         except:
             finish_reason = "API_ERROR"
-        return respose, finish_reason
+        return response, reasoning_content, finish_reason
 
     try:
         if chunk["choices"][0]["delta"]["content"] is not None:
-            respose = chunk["choices"][0]["delta"]["content"]
+            response = chunk["choices"][0]["delta"]["content"]
     except:
         pass
     try:
@@ -71,7 +76,7 @@ def decode_chunk(chunk):
         finish_reason = chunk["choices"][0]["finish_reason"]
     except:
         pass
-    return respose, reasoning_content, finish_reason
+    return response, reasoning_content, finish_reason
 
 
 def generate_message(input, model, key, history, max_output_token, system_prompt, temperature):
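
A quick, hypothetical smoke test (not part of the commit) of decode_chunk's new behavior: a text/event-stream chunk carrying the "data: " prefix and a bare text/json body should now both decode, and every path returns the same three-tuple:

    # Hypothetical check; assumes decode_chunk is importable from
    # request_llms.oai_std_model_template.
    sse = b'data: {"choices": [{"delta": {"content": "Hi"}, "finish_reason": "stop"}]}'
    plain = b'{"choices": [{"delta": {"content": "Hi"}, "finish_reason": "stop"}]}'
    for raw in (sse, plain):
        response, reasoning_content, finish_reason = decode_chunk(raw)
        print(response, finish_reason)  # expected: Hi stop
    # Before this fix, the bare JSON case was broken: json.loads(chunk[6:])
    # always sliced six characters off, and the error path returned a 2-tuple.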
@@ -106,15 +111,15 @@ def generate_message(input, model, key, history, max_output_token, system_prompt
     what_i_ask_now["role"] = "user"
     what_i_ask_now["content"] = input
     messages.append(what_i_ask_now)
-    playload = {
+    payload = {
         "model": model,
         "messages": messages,
         "temperature": temperature,
         "stream": True,
         "max_tokens": max_output_token,
     }
 
-    return headers, playload
+    return headers, payload
 
 
 def get_predict_function(
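
For context, a hypothetical call site (all argument values invented here): generate_message builds the OpenAI-style chat-completion payload that the renamed variable carries into the streaming POST calls shown further down:

    # Hypothetical caller; values are placeholders, not from the commit.
    headers, payload = generate_message(
        input="Hello",
        model="gpt-3.5-turbo",
        key="YOUR_API_KEY",
        history=[],
        max_output_token=1024,
        system_prompt="You are a helpful assistant.",
        temperature=0.5,
    )
    # response = requests.post(endpoint, headers=headers, json=payload, stream=True)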
@@ -141,7 +146,7 @@ def predict_no_ui_long_connection(
         history=[],
         sys_prompt="",
         observe_window=None,
-        console_slience=False,
+        console_silence=False,
     ):
         """
         Send to chatGPT and wait for the full reply in one go, without showing intermediate progress; internally uses streaming to avoid the connection being cut mid-way.
@@ -162,7 +167,7 @@ def predict_no_ui_long_connection(
             raise RuntimeError(f"APIKEY为空,请检查配置文件的{APIKEY}")
         if inputs == "":
             inputs = "你好👋"
-        headers, playload = generate_message(
+        headers, payload = generate_message(
             input=inputs,
             model=llm_kwargs["llm_model"],
             key=APIKEY,
@@ -182,7 +187,7 @@ def predict_no_ui_long_connection(
             endpoint,
             headers=headers,
             proxies=None if disable_proxy else proxies,
-            json=playload,
+            json=payload,
             stream=True,
             timeout=TIMEOUT_SECONDS,
         )
@@ -198,7 +203,7 @@ def predict_no_ui_long_connection(
         result = ""
         finish_reason = ""
         if reasoning:
-            resoning_buffer = ""
+            reasoning_buffer = ""
 
         stream_response = response.iter_lines()
         while True:
@@ -226,12 +231,12 @@ def predict_no_ui_long_connection(
             if chunk:
                 try:
                     if finish_reason == "stop":
-                        if not console_slience:
+                        if not console_silence:
                             print(f"[response] {result}")
                         break
                     result += response_text
                     if reasoning:
-                        resoning_buffer += reasoning_content
+                        reasoning_buffer += reasoning_content
                     if observe_window is not None:
                         # Observation window: push the data received so far to the display
                         if len(observe_window) >= 1:
@@ -247,9 +252,9 @@ def predict_no_ui_long_connection(
             logger.error(error_msg)
             raise RuntimeError("Json解析不合常规")
         if reasoning:
-            # Box the reasoning section as a quote (>)
-            return '\n'.join(map(lambda x: '> ' + x, resoning_buffer.split('\n'))) + \
-                '\n\n' + result
+            return f'''<div style="padding: 1em; line-height: 1.5; text-wrap: wrap; opacity: 0.8">
+                {''.join([f'<p style="margin: 1.25em 0;">{line}</p>' for line in reasoning_buffer.split('\n')])}
+                </div>\n\n''' + result
         return result
 
     def predict(
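
A hypothetical illustration (names and text invented here) of what the new reasoning rendering produces, with the split pulled out of the f-string for portability to pre-3.12 Python:

    # Each reasoning line becomes its own dimmed <p> block, replacing the old
    # "> "-prefixed Markdown blockquote rendering.
    reasoning_buffer = "First, restate the question.\nThen check the constraints."
    lines = reasoning_buffer.split("\n")
    paragraphs = "".join(f'<p style="margin: 1.25em 0;">{line}</p>' for line in lines)
    html = (
        '<div style="padding: 1em; line-height: 1.5; text-wrap: wrap; opacity: 0.8">'
        + paragraphs
        + "</div>\n\n"
    )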
@@ -268,7 +273,7 @@ def predict(
         inputs is the input for this query
         top_p, temperature are chatGPT's internal tuning parameters
         history is the list of previous conversation turns (note: if either inputs or history is too long, a token-overflow error will be triggered)
-        chatbot is the conversation list shown in the WebUI; modify it and then yeild it out to update the chat interface directly
+        chatbot is the conversation list shown in the WebUI; modify it and then yield it out to update the chat interface directly
         additional_fn indicates which button was clicked; see functional.py for the buttons
         """
         from .bridge_all import model_info
@@ -299,7 +304,7 @@ def predict(
         )  # Refresh the UI
         time.sleep(2)
 
-        headers, playload = generate_message(
+        headers, payload = generate_message(
             input=inputs,
             model=llm_kwargs["llm_model"],
             key=APIKEY,
@@ -321,7 +326,7 @@ def predict(
             endpoint,
             headers=headers,
             proxies=None if disable_proxy else proxies,
-            json=playload,
+            json=payload,
             stream=True,
             timeout=TIMEOUT_SECONDS,
         )
@@ -367,7 +372,7 @@ def predict(
             chunk_decoded = chunk.decode()
             chatbot[-1] = (
                 chatbot[-1][0],
-                "[Local Message] {finish_reason},获得以下报错信息:\n"
+                f"[Local Message] {finish_reason},获得以下报错信息:\n"
                 + chunk_decoded,
             )
             yield from update_ui(
@@ -385,7 +390,9 @@ def predict(
             if reasoning:
                 gpt_replying_buffer += response_text
                 gpt_reasoning_buffer += reasoning_content
-                history[-1] = '\n'.join(map(lambda x: '> ' + x, gpt_reasoning_buffer.split('\n'))) + '\n\n' + gpt_replying_buffer
+                history[-1] = f'''<div style="padding: 1em; line-height: 1.5; text-wrap: wrap; opacity: 0.8">
+                    {''.join([f'<p style="margin: 1.25em 0;">{line}</p>' for line in gpt_reasoning_buffer.split('\n')])}
+                    </div>\n\n''' + gpt_replying_buffer
             else:
                 gpt_replying_buffer += response_text
             # If an exception is raised here, it is usually because the text is too long; see the output of get_full_error for details
