11"""
2- baidu voice service
2+ baidu voice service with thread-safe token caching
33"""
44import json
55import os
66import time
7+ import threading
8+ import requests
79
810from aip import AipSpeech
911
1416from voice .audio_convert import get_pcm_from_wav
1517from voice .voice import Voice
1618
17- """
18- 百度的语音识别API.
19- dev_pid:
20- - 1936: 普通话远场
21- - 1536:普通话(支持简单的英文识别)
22- - 1537:普通话(纯中文识别)
23- - 1737:英语
24- - 1637:粤语
25- - 1837:四川话
26- 要使用本模块, 首先到 yuyin.baidu.com 注册一个开发者账号,
27- 之后创建一个新应用, 然后在应用管理的"查看key"中获得 API Key 和 Secret Key
28- 然后在 config.json 中填入这两个值, 以及 app_id, dev_pid
29- """
30-
31-
3219class BaiduVoice (Voice ):
3320 def __init__ (self ):
3421 try :
22+ # 读取本地 TTS 参数配置
3523 curdir = os .path .dirname (__file__ )
3624 config_path = os .path .join (curdir , "config.json" )
37- bconf = None
38- if not os .path .exists (config_path ): # 如果没有配置文件,创建本地配置文件
25+ if not os .path .exists (config_path ):
3926 bconf = {"lang" : "zh" , "ctp" : 1 , "spd" : 5 , "pit" : 5 , "vol" : 5 , "per" : 0 }
4027 with open (config_path , "w" ) as fw :
4128 json .dump (bconf , fw , indent = 4 )
@@ -47,48 +34,139 @@ def __init__(self):
4734 self .api_key = str (conf ().get ("baidu_api_key" ))
4835 self .secret_key = str (conf ().get ("baidu_secret_key" ))
4936 self .dev_id = conf ().get ("baidu_dev_pid" )
37+
5038 self .lang = bconf ["lang" ]
51- self .ctp = bconf ["ctp" ]
52- self .spd = bconf ["spd" ]
53- self .pit = bconf ["pit" ]
54- self .vol = bconf ["vol" ]
55- self .per = bconf ["per" ]
39+ self .ctp = bconf ["ctp" ]
40+ self .spd = bconf ["spd" ]
41+ self .pit = bconf ["pit" ]
42+ self .vol = bconf ["vol" ]
43+ self .per = bconf ["per" ]
5644
45+ # 百度 SDK 客户端(短文本合成 & 语音识别)
5746 self .client = AipSpeech (self .app_id , self .api_key , self .secret_key )
47+
48+ # access_token 缓存与锁
49+ self ._access_token = None
50+ self ._token_expire_ts = 0
51+ self ._token_lock = threading .Lock ()
5852 except Exception as e :
59- logger .warn ("BaiduVoice init failed: %s, ignore " % e )
53+ logger .warn ("BaiduVoice init failed: %s, ignore" % e )
54+
55+ def _get_access_token (self ):
56+ # 多线程安全获取 token
57+ with self ._token_lock :
58+ now = time .time ()
59+ if self ._access_token and now < self ._token_expire_ts :
60+ return self ._access_token
61+ url = "https://aip.baidubce.com/oauth/2.0/token"
62+ params = {
63+ "grant_type" : "client_credentials" ,
64+ "client_id" : self .api_key ,
65+ "client_secret" : self .secret_key ,
66+ }
67+ resp = requests .post (url , params = params ).json ()
68+ token = resp .get ("access_token" )
69+ expires_in = resp .get ("expires_in" , 2592000 )
70+ if token :
71+ self ._access_token = token
72+ self ._token_expire_ts = now + expires_in - 60 # 提前 1 分钟过期
73+ return token
74+ else :
75+ logger .error ("BaiduVoice _get_access_token failed: %s" , resp )
76+ return None
6077
6178 def voiceToText (self , voice_file ):
62- # 识别本地文件
63- logger .debug ("[Baidu] voice file name={}" .format (voice_file ))
79+ logger .debug ("[Baidu] recognize voice file=%s" , voice_file )
6480 pcm = get_pcm_from_wav (voice_file )
6581 res = self .client .asr (pcm , "pcm" , 16000 , {"dev_pid" : self .dev_id })
66- if res ["err_no" ] == 0 :
67- logger .info ("百度语音识别到了:{}" .format (res ["result" ]))
82+ if res .get ("err_no" ) == 0 :
6883 text = "" .join (res ["result" ])
69- reply = Reply (ReplyType .TEXT , text )
84+ logger .info ("[Baidu] ASR result: %s" , text )
85+ return Reply (ReplyType .TEXT , text )
7086 else :
71- logger .info ("百度语音识别出错了: {}" .format (res ["err_msg" ]))
72- if res ["err_msg" ] == "request pv too much" :
73- logger .info (" 出现这个原因很可能是你的百度语音服务调用量超出限制,或未开通付费" )
74- reply = Reply (ReplyType .ERROR , "百度语音识别出错了;{0}" .format (res ["err_msg" ]))
75- return reply
87+ err = res .get ("err_msg" , "" )
88+ logger .error ("[Baidu] ASR error: %s" , err )
89+ return Reply (ReplyType .ERROR , f"语音识别失败:{ err } " )
7690
77- def textToVoice (self , text ):
78- result = self .client .synthesis (
79- text ,
80- self .lang ,
81- self .ctp ,
82- {"spd" : self .spd , "pit" : self .pit , "vol" : self .vol , "per" : self .per },
83- )
84- if not isinstance (result , dict ):
85- # Avoid the same filename under multithreading
86- fileName = TmpDir ().path () + "reply-" + str (int (time .time ())) + "-" + str (hash (text ) & 0x7FFFFFFF ) + ".mp3"
87- with open (fileName , "wb" ) as f :
88- f .write (result )
89- logger .info ("[Baidu] textToVoice text={} voice file name={}" .format (text , fileName ))
90- reply = Reply (ReplyType .VOICE , fileName )
91+ def _long_text_synthesis (self , text ):
92+ token = self ._get_access_token ()
93+ if not token :
94+ return Reply (ReplyType .ERROR , "获取百度 access_token 失败" )
95+
96+ # 创建合成任务
97+ create_url = f"https://aip.baidubce.com/rpc/2.0/tts/v1/create?access_token={ token } "
98+ payload = {
99+ "text" : text ,
100+ "format" : "mp3-16k" ,
101+ "voice" : 0 ,
102+ "lang" : self .lang ,
103+ "speed" : self .spd ,
104+ "pitch" : self .pit ,
105+ "volume" : self .vol ,
106+ "enable_subtitle" : 0 ,
107+ }
108+ headers = {"Content-Type" : "application/json" }
109+ create_resp = requests .post (create_url , headers = headers , json = payload ).json ()
110+ task_id = create_resp .get ("task_id" )
111+ if not task_id :
112+ logger .error ("[Baidu] 长文本合成创建任务失败: %s" , create_resp )
113+ return Reply (ReplyType .ERROR , "长文本合成任务提交失败" )
114+ logger .info ("[Baidu] 长文本合成任务已提交 task_id=%s" , task_id )
115+
116+ # 轮询查询任务状态
117+ query_url = f"https://aip.baidubce.com/rpc/2.0/tts/v1/query?access_token={ token } "
118+ for _ in range (100 ):
119+ time .sleep (3 )
120+ resp = requests .post (query_url , headers = headers , json = {"task_ids" :[task_id ]})
121+ result = resp .json ()
122+ infos = result .get ("tasks_info" ) or result .get ("tasks" ) or []
123+ if not infos :
124+ continue
125+ info = infos [0 ]
126+ status = info .get ("task_status" )
127+ if status == "Success" :
128+ task_res = info .get ("task_result" , {})
129+ audio_url = task_res .get ("audio_address" ) or task_res .get ("speech_url" )
130+ break
131+ elif status == "Running" :
132+ continue
133+ else :
134+ logger .error ("[Baidu] 长文本合成失败: %s" , info )
135+ return Reply (ReplyType .ERROR , "长文本合成执行失败" )
91136 else :
92- logger .error ("[Baidu] textToVoice error={}" .format (result ))
93- reply = Reply (ReplyType .ERROR , "抱歉,语音合成失败" )
94- return reply
137+ return Reply (ReplyType .ERROR , "长文本合成超时,请稍后重试" )
138+
139+ # 下载并保存音频
140+ audio_data = requests .get (audio_url ).content
141+ fn = TmpDir ().path () + f"reply-long-{ int (time .time ())} -{ hash (text )& 0x7FFFFFFF } .mp3"
142+ with open (fn , "wb" ) as f :
143+ f .write (audio_data )
144+ logger .info ("[Baidu] 长文本合成 success: %s" , fn )
145+ return Reply (ReplyType .VOICE , fn )
146+
147+ def textToVoice (self , text ):
148+ try :
149+ # GBK 编码字节长度
150+ gbk_len = len (text .encode ("gbk" , errors = "ignore" ))
151+ if gbk_len <= 1024 :
152+ # 短文本走 SDK 合成
153+ result = self .client .synthesis (
154+ text , self .lang , self .ctp ,
155+ {"spd" :self .spd , "pit" :self .pit , "vol" :self .vol , "per" :self .per }
156+ )
157+ if not isinstance (result , dict ):
158+ fn = TmpDir ().path () + f"reply-{ int (time .time ())} -{ hash (text )& 0x7FFFFFFF } .mp3"
159+ with open (fn , "wb" ) as f :
160+ f .write (result )
161+ logger .info ("[Baidu] 短文本合成 success: %s" , fn )
162+ return Reply (ReplyType .VOICE , fn )
163+ else :
164+ logger .error ("[Baidu] 短文本合成 error: %s" , result )
165+ return Reply (ReplyType .ERROR , "短文本语音合成失败" )
166+ else :
167+ # 长文本
168+ return self ._long_text_synthesis (text )
169+ except Exception as e :
170+ logger .error ("BaiduVoice textToVoice exception: %s" , e )
171+ return Reply (ReplyType .ERROR , f"合成异常:{ e } " )
172+
0 commit comments