1
1
import json
2
2
import time
3
3
import traceback
4
+
4
5
import requests
5
6
from loguru import logger
6
7
7
8
# config_private.py放自己的秘密如API和代理网址
8
9
# 读取时首先看是否存在私密的config_private配置文件(不受git管控),如果有,则覆盖原config文件
9
- from toolbox import (
10
- get_conf ,
11
- update_ui ,
12
- is_the_upload_folder ,
13
- )
10
+ from toolbox import get_conf , is_the_upload_folder , update_ui
14
11
15
12
proxies , TIMEOUT_SECONDS , MAX_RETRY = get_conf (
16
13
"proxies" , "TIMEOUT_SECONDS" , "MAX_RETRY"
@@ -39,27 +36,35 @@ def decode_chunk(chunk):
39
36
用于解读"content"和"finish_reason"的内容(如果支持思维链也会返回"reasoning_content"内容)
40
37
"""
41
38
chunk = chunk .decode ()
42
- respose = ""
39
+ response = ""
43
40
reasoning_content = ""
44
41
finish_reason = "False"
42
+
43
+ # 考虑返回类型是 text/json 和 text/event-stream 两种
44
+ if chunk .startswith ("data: " ):
45
+ chunk = chunk [6 :]
46
+ else :
47
+ chunk = chunk
48
+
45
49
try :
46
- chunk = json .loads (chunk [ 6 :] )
50
+ chunk = json .loads (chunk )
47
51
except :
48
- respose = ""
52
+ response = ""
49
53
finish_reason = chunk
54
+
50
55
# 错误处理部分
51
56
if "error" in chunk :
52
- respose = "API_ERROR"
57
+ response = "API_ERROR"
53
58
try :
54
59
chunk = json .loads (chunk )
55
60
finish_reason = chunk ["error" ]["code" ]
56
61
except :
57
62
finish_reason = "API_ERROR"
58
- return respose , finish_reason
63
+ return response , reasoning_content , finish_reason
59
64
60
65
try :
61
66
if chunk ["choices" ][0 ]["delta" ]["content" ] is not None :
62
- respose = chunk ["choices" ][0 ]["delta" ]["content" ]
67
+ response = chunk ["choices" ][0 ]["delta" ]["content" ]
63
68
except :
64
69
pass
65
70
try :
@@ -71,7 +76,7 @@ def decode_chunk(chunk):
71
76
finish_reason = chunk ["choices" ][0 ]["finish_reason" ]
72
77
except :
73
78
pass
74
- return respose , reasoning_content , finish_reason
79
+ return response , reasoning_content , finish_reason
75
80
76
81
77
82
def generate_message (input , model , key , history , max_output_token , system_prompt , temperature ):
@@ -106,15 +111,15 @@ def generate_message(input, model, key, history, max_output_token, system_prompt
106
111
what_i_ask_now ["role" ] = "user"
107
112
what_i_ask_now ["content" ] = input
108
113
messages .append (what_i_ask_now )
109
- playload = {
114
+ payload = {
110
115
"model" : model ,
111
116
"messages" : messages ,
112
117
"temperature" : temperature ,
113
118
"stream" : True ,
114
119
"max_tokens" : max_output_token ,
115
120
}
116
121
117
- return headers , playload
122
+ return headers , payload
118
123
119
124
120
125
def get_predict_function (
@@ -141,7 +146,7 @@ def predict_no_ui_long_connection(
141
146
history = [],
142
147
sys_prompt = "" ,
143
148
observe_window = None ,
144
- console_slience = False ,
149
+ console_silence = False ,
145
150
):
146
151
"""
147
152
发送至chatGPT,等待回复,一次性完成,不显示中间过程。但内部用stream的方法避免中途网线被掐。
@@ -162,7 +167,7 @@ def predict_no_ui_long_connection(
162
167
raise RuntimeError (f"APIKEY为空,请检查配置文件的{ APIKEY } " )
163
168
if inputs == "" :
164
169
inputs = "你好👋"
165
- headers , playload = generate_message (
170
+ headers , payload = generate_message (
166
171
input = inputs ,
167
172
model = llm_kwargs ["llm_model" ],
168
173
key = APIKEY ,
@@ -182,7 +187,7 @@ def predict_no_ui_long_connection(
182
187
endpoint ,
183
188
headers = headers ,
184
189
proxies = None if disable_proxy else proxies ,
185
- json = playload ,
190
+ json = payload ,
186
191
stream = True ,
187
192
timeout = TIMEOUT_SECONDS ,
188
193
)
@@ -198,7 +203,7 @@ def predict_no_ui_long_connection(
198
203
result = ""
199
204
finish_reason = ""
200
205
if reasoning :
201
- resoning_buffer = ""
206
+ reasoning_buffer = ""
202
207
203
208
stream_response = response .iter_lines ()
204
209
while True :
@@ -226,12 +231,12 @@ def predict_no_ui_long_connection(
226
231
if chunk :
227
232
try :
228
233
if finish_reason == "stop" :
229
- if not console_slience :
234
+ if not console_silence :
230
235
print (f"[response] { result } " )
231
236
break
232
237
result += response_text
233
238
if reasoning :
234
- resoning_buffer += reasoning_content
239
+ reasoning_buffer += reasoning_content
235
240
if observe_window is not None :
236
241
# 观测窗,把已经获取的数据显示出去
237
242
if len (observe_window ) >= 1 :
@@ -247,9 +252,9 @@ def predict_no_ui_long_connection(
247
252
logger .error (error_msg )
248
253
raise RuntimeError ("Json解析不合常规" )
249
254
if reasoning :
250
- # reasoning 的部分加上框 (>)
251
- return ' \n ' .join (map ( lambda x : '> ' + x , resoning_buffer .split ('\n ' ))) + \
252
- ' \n \n ' + result
255
+ return f'''<div style="padding: 1em; line-height: 1.5; text-wrap: wrap; opacity: 0.8">
256
+ { ' ' .join ([ f'<p style="margin: 1.25em 0;"> { line } </p>' for line in reasoning_buffer .split ('\n ' )]) }
257
+ </div> \n \n '' ' + result
253
258
return result
254
259
255
260
def predict (
@@ -268,7 +273,7 @@ def predict(
268
273
inputs 是本次问询的输入
269
274
top_p, temperature是chatGPT的内部调优参数
270
275
history 是之前的对话列表(注意无论是inputs还是history,内容太长了都会触发token数量溢出的错误)
271
- chatbot 为WebUI中显示的对话列表,修改它,然后yeild出去 ,可以直接修改对话界面内容
276
+ chatbot 为WebUI中显示的对话列表,修改它,然后yield出去 ,可以直接修改对话界面内容
272
277
additional_fn代表点击的哪个按钮,按钮见functional.py
273
278
"""
274
279
from .bridge_all import model_info
@@ -299,7 +304,7 @@ def predict(
299
304
) # 刷新界面
300
305
time .sleep (2 )
301
306
302
- headers , playload = generate_message (
307
+ headers , payload = generate_message (
303
308
input = inputs ,
304
309
model = llm_kwargs ["llm_model" ],
305
310
key = APIKEY ,
@@ -321,7 +326,7 @@ def predict(
321
326
endpoint ,
322
327
headers = headers ,
323
328
proxies = None if disable_proxy else proxies ,
324
- json = playload ,
329
+ json = payload ,
325
330
stream = True ,
326
331
timeout = TIMEOUT_SECONDS ,
327
332
)
@@ -367,7 +372,7 @@ def predict(
367
372
chunk_decoded = chunk .decode ()
368
373
chatbot [- 1 ] = (
369
374
chatbot [- 1 ][0 ],
370
- "[Local Message] {finish_reason},获得以下报错信息:\n "
375
+ f "[Local Message] { finish_reason } ,获得以下报错信息:\n "
371
376
+ chunk_decoded ,
372
377
)
373
378
yield from update_ui (
@@ -385,7 +390,9 @@ def predict(
385
390
if reasoning :
386
391
gpt_replying_buffer += response_text
387
392
gpt_reasoning_buffer += reasoning_content
388
- history [- 1 ] = '\n ' .join (map (lambda x : '> ' + x , gpt_reasoning_buffer .split ('\n ' ))) + '\n \n ' + gpt_replying_buffer
393
+ history [- 1 ] = f'''<div style="padding: 1em; line-height: 1.5; text-wrap: wrap; opacity: 0.8">
394
+ { '' .join ([f'<p style="margin: 1.25em 0;">{ line } </p>' for line in gpt_reasoning_buffer .split ('\n ' )])}
395
+ </div>\n \n ''' + gpt_replying_buffer
389
396
else :
390
397
gpt_replying_buffer += response_text
391
398
# 如果这里抛出异常,一般是文本过长,详情见get_full_error的输出
0 commit comments