Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file added Arona_Academy_In_2.ogg.wav
Binary file not shown.
44 changes: 30 additions & 14 deletions GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,13 @@
from TTS_infer_pack.text_segmentation_method import split_big_text, splits, get_method as get_seg_method

from tools.i18n.i18n import I18nAuto, scan_language_list
from functools import lru_cache
import torch

from .cached1 import get_cached_bert
from .cached1 import CachedBertExtractor



language = os.environ.get("language", "Auto")
language = sys.argv[-1] if sys.argv[-1] in scan_language_list() else language
Expand Down Expand Up @@ -56,6 +63,8 @@ def __init__(self, bert_model: AutoModelForMaskedLM, tokenizer: AutoTokenizer, d
self.device = device
self.bert_lock = threading.RLock()

self.bert_extractor = CachedBertExtractor("bert-base-chinese", device=device)

def preprocess(self, text: str, lang: str, text_split_method: str, version: str = "v2") -> List[Dict]:
print(f"############ {i18n('切分文本')} ############")
text = self.replace_consecutive_punctuation(text)
Expand Down Expand Up @@ -186,20 +195,25 @@ def get_phones_and_bert(self, text: str, language: str, version: str, final: boo

return phones, bert, norm_text

def get_bert_feature(self, text: str, word2ph: list) -> torch.Tensor:
    """Extract phone-level BERT features for *text*.

    Args:
        text: normalized input text; must have exactly one entry in
            ``word2ph`` per character (enforced by the assert below).
        word2ph: number of phonemes each character expands to.

    Returns:
        Tensor of shape ``[hidden_dim, sum(word2ph)]`` (note the transpose
        at the end), on CPU.
    """
    with torch.no_grad():
        inputs = self.tokenizer(text, return_tensors="pt")
        for i in inputs:
            inputs[i] = inputs[i].to(self.device)
        res = self.bert_model(**inputs, output_hidden_states=True)
        # hidden_states[-3:-2] is a one-element slice (third-from-last
        # layer), so the cat leaves the hidden dim unchanged; [1:-1] drops
        # the CLS/SEP positions so rows align with characters of `text`.
        res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1]
    assert len(word2ph) == len(text)
    # Repeat each character's feature row word2ph[i] times to go from
    # word/char level to phone level.
    phone_level_feature = []
    for i in range(len(word2ph)):
        repeat_feature = res[i].repeat(word2ph[i], 1)
        phone_level_feature.append(repeat_feature)
    phone_level_feature = torch.cat(phone_level_feature, dim=0)
    return phone_level_feature.T
# def get_bert_feature(self, text: str, word2ph: list) -> torch.Tensor:
# with torch.no_grad():
# inputs = self.tokenizer(text, return_tensors="pt")
# for i in inputs:
# inputs[i] = inputs[i].to(self.device)
# res = self.bert_model(**inputs, output_hidden_states=True)
# res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1]
# assert len(word2ph) == len(text)
# phone_level_feature = []
# for i in range(len(word2ph)):
# repeat_feature = res[i].repeat(word2ph[i], 1)
# phone_level_feature.append(repeat_feature)
# phone_level_feature = torch.cat(phone_level_feature, dim=0)
# return phone_level_feature.T

def get_bert_feature(self, norm_text: str, word2ph: list) -> torch.Tensor:
    """Fetch phone-level BERT features for *norm_text* via the module-level LRU cache."""
    # word2ph is a list and lists are unhashable, so it becomes a tuple to
    # serve as part of the lru_cache key.
    cached = get_cached_bert(norm_text, tuple(word2ph), str(self.device))
    return cached.to(self.device)

def clean_text_inf(self, text: str, language: str, version: str = "v2"):
language = language.replace("all_", "")
Expand Down Expand Up @@ -235,3 +249,5 @@ def replace_consecutive_punctuation(self, text):
pattern = f"([{punctuations}])([{punctuations}])+"
result = re.sub(pattern, r"\1", text)
return result


75 changes: 75 additions & 0 deletions GPT_SoVITS/TTS_infer_pack/cached1.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
from functools import lru_cache
import torch
import torch
from functools import lru_cache
from transformers import AutoTokenizer, AutoModelForMaskedLM
from typing import List, Tuple

@lru_cache(maxsize=2)
def _load_bert(device_str: str):
    """Load and cache the tokenizer/model pair, once per device string.

    Bug fix: previously the tokenizer and the full BERT model were
    re-instantiated inside ``get_cached_bert`` on every cache miss, paying
    the entire model-load cost per new sentence.  Loading is now done once
    and reused.
    """
    tokenizer = AutoTokenizer.from_pretrained("bert-base-chinese")
    model = (
        AutoModelForMaskedLM.from_pretrained("bert-base-chinese", output_hidden_states=True)
        .eval()
        .to(device_str)
    )
    return tokenizer, model


@lru_cache(maxsize=1000)
def get_cached_bert(norm_text: str, word2ph_tuple: tuple, device_str: str = "cuda"):
    """Cached BERT feature extraction, reusing results for identical norm_text.

    Args:
        norm_text (str): normalized text (reusable cache key).
        word2ph_tuple (tuple): word2ph list converted to a tuple
            (lru_cache cannot hash a list).
        device_str (str): device to run the model on.

    Returns:
        Tensor: shape [hidden_dim, total_phonemes], on CPU so cached
        entries do not pin GPU memory.
    """
    tokenizer, model = _load_bert(device_str)

    inputs = tokenizer(norm_text, return_tensors="pt").to(device_str)
    with torch.no_grad():
        outputs = model(**inputs)
    # hidden_states[-3:-2] selects exactly one layer (third from last);
    # [1:-1] drops the CLS/SEP positions.
    hidden = torch.cat(outputs.hidden_states[-3:-2], dim=-1)[0][1:-1]
    word2ph = torch.tensor(word2ph_tuple, device=hidden.device)
    # Row i of `hidden` is repeated word2ph[i] times (phone-level expansion).
    indices = torch.repeat_interleave(torch.arange(len(word2ph_tuple), device=hidden.device), word2ph)
    phone_level_feature = hidden[indices]
    return phone_level_feature.T.cpu()





class CachedBertExtractor:
    """BERT phone-level feature extractor with a per-instance LRU cache.

    Bug fix: the original applied ``@lru_cache`` directly to an instance
    method.  That caches on ``self`` as part of the key and keeps the
    instance — including its GPU-resident model — alive for the lifetime of
    the cache (ruff B019).  The cache is now created per instance in
    ``__init__``, so it dies with the instance.
    """

    def __init__(self, model_name_or_path: str = "bert-base-chinese", device: str = "cuda"):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        self.device = device
        self.bert_model = (
            AutoModelForMaskedLM.from_pretrained(model_name_or_path, output_hidden_states=True)
            .eval()
            .to(device)
        )
        # Per-instance cache wrapping the uncached computation below.
        self._cached_bert = lru_cache(maxsize=1024)(self._compute_bert_feature)

    def get_bert_feature(self, norm_text: str, word2ph: List[int]) -> torch.Tensor:
        """Return the (possibly cached) phone-level BERT feature on ``self.device``."""
        # word2ph must be hashable to act as part of the cache key.
        word2ph_tuple = tuple(word2ph)
        return self._cached_bert(norm_text, word2ph_tuple).to(self.device)

    def _compute_bert_feature(self, norm_text: str, word2ph_tuple: Tuple[int, ...]) -> torch.Tensor:
        """Uncached computation; returns a CPU tensor so cache entries don't hold GPU memory."""
        inputs = self.tokenizer(norm_text, return_tensors="pt").to(self.device)
        with torch.no_grad():
            outputs = self.bert_model(**inputs)
        # One layer (third from last), CLS/SEP dropped: [seq_len-2, hidden_dim].
        hidden = torch.cat(outputs.hidden_states[-3:-2], dim=-1)[0][1:-1]

        word2ph_tensor = torch.tensor(list(word2ph_tuple), device=self.device)
        indices = torch.repeat_interleave(torch.arange(len(word2ph_tuple), device=self.device), word2ph_tensor)
        phone_level_feature = hidden[indices]  # [sum(word2ph), hidden_size]

        return phone_level_feature.T.cpu()  # cache-safe

    def clear_cache(self):
        """Clear this instance's BERT feature cache."""
        self._cached_bert.cache_clear()
74 changes: 70 additions & 4 deletions GPT_SoVITS/export_torch_script_v3.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,13 @@
import soundfile
from librosa.filters import mel as librosa_mel_fn

import time
import random
import torch
from tqdm import tqdm
from transformers import BertTokenizer
# tokenizer = BertTokenizer.from_pretrained("bert-base-chinese")


from inference_webui import get_spepc, norm_spec, resample, ssl_model

Expand Down Expand Up @@ -921,6 +928,24 @@ def test_export1(
import time


@torch.jit.script
def build_phone_level_feature(res: torch.Tensor, word2ph: torch.Tensor) -> torch.Tensor:
    """Expand word-level BERT features to phone level.

    Row ``i`` of *res* is repeated ``word2ph[i]`` times, i.e. each word's
    feature is duplicated once per phoneme it maps to.

    Args:
        res: [N_words, hidden_dim] word-level features.
        word2ph: [N_words] integer tensor; phonemes per word.

    Returns:
        [sum(word2ph), hidden_dim] phone-level features.
    """
    # Vectorized replacement for the original Python loop of .item()/.repeat()
    # calls — one kernel instead of N small ops, identical output.
    return torch.repeat_interleave(res, word2ph.to(torch.long), dim=0)


def test_():
sovits = get_sovits_weights("GPT_SoVITS/pretrained_models/s2Gv3.pth")

Expand Down Expand Up @@ -1010,6 +1035,23 @@ def test_():
# )


def extract_bert_features(texts: list, desc: str = "提取文本Bert特征"):
    """Mock BERT feature extraction with a tqdm progress bar.

    For each text: tokenize, build a random 768-dim tensor sized to the
    token count, and sleep 0.8–1.6 s.  NOTE(review): no real BERT forward
    pass happens — this looks like a progress-bar demo/benchmark; confirm
    before relying on its output.

    Args:
        texts: input strings to iterate over.
        desc: progress-bar description.
    """
    # Bug fix: this function referenced a module-level `tokenizer` whose
    # definition was commented out, so every call raised NameError.  Create
    # the tokenizer lazily and cache it on the function object instead.
    tok = getattr(extract_bert_features, "_tokenizer", None)
    if tok is None:
        tok = BertTokenizer.from_pretrained("bert-base-chinese")
        extract_bert_features._tokenizer = tok

    for text in tqdm(texts, desc=desc, unit="it"):
        # Tokenize and map to vocabulary ids.
        tokens = tok.tokenize(text)
        input_ids = tok.convert_tokens_to_ids(tokens)

        fake_tensor = torch.randn(768, len(input_ids))
        _ = fake_tensor.mean(dim=1)

        # Simulated per-item latency.
        delay = round(random.uniform(0.8, 1.6), 2)
        time.sleep(delay)


def test_export_gpt_sovits_v3():
gpt_sovits_v3 = torch.jit.load("onnx/ad/gpt_sovits_v3.pt", map_location=device)
# test_export1(
Expand All @@ -1029,7 +1071,31 @@ def test_export_gpt_sovits_v3():
)


with torch.no_grad():
# export()
test_()
# test_export_gpt_sovits_v3()
class MyBertModel(torch.nn.Module):
    """Wrap a masked-LM backbone so its word-level hidden states come out as
    phone-level features — a traceable module for TorchScript export.
    """

    def __init__(self, bert_model):
        super(MyBertModel, self).__init__()
        # Backbone must accept input_ids/attention_mask/token_type_ids and,
        # with output_hidden_states=True, return an object exposing
        # ``.hidden_states`` (HF transformers convention).
        self.bert = bert_model

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, token_type_ids: torch.Tensor, word2ph: torch.IntTensor):
        """Return phone-level features of shape [hidden_dim, sum(word2ph)]."""
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            output_hidden_states=True,
        )
        hidden_states = outputs.hidden_states
        # [-3:-2] is a one-element slice (third-from-last layer), so the cat
        # keeps the hidden dim; [1:-1] drops the CLS/SEP positions.
        res = torch.cat(hidden_states[-3:-2], -1)[0][1:-1]
        # Vectorized word->phone expansion: row i of `res` is repeated
        # word2ph[i] times (replaces the per-word Python .repeat() loop).
        phone_level_feature = torch.repeat_interleave(res, word2ph.to(torch.long), dim=0)
        return phone_level_feature.T








# with torch.no_grad():
# # export()
# # test_()
# # test_export_gpt_sovits_v3()
# print()
4 changes: 2 additions & 2 deletions GPT_SoVITS/inference_webui_fast.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,8 @@
else:
device = "cpu"

# is_half = False
# device = "cpu"
is_half = True
device = "cpu"

dict_language_v1 = {
i18n("中文"): "all_zh", # 全部按中文识别
Expand Down
Binary file added GPT_SoVITS/text/ja_userdic/user.dict
Binary file not shown.
1 change: 1 addition & 0 deletions GPT_SoVITS/text/ja_userdic/userdict.md5
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
878b3caf4d1cd7c2927c26e85072a2f5
28 changes: 28 additions & 0 deletions GPT_SoVITS/torch2torchscript_pack.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
"""Export the BERT wrapper model to TorchScript via tracing.

Loads the Chinese RoBERTa checkpoint, wraps it in MyBertModel, traces it
with one example sentence, and saves the scripted module to disk.
"""
# NOTE(review): `build_phone_level_feature` is imported but never used here.
from transformers import AutoTokenizer, AutoModelForMaskedLM
import torch
from export_torch_script_v3 import MyBertModel, build_phone_level_feature

bert_path = "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large"
tokenizer = AutoTokenizer.from_pretrained(bert_path)
model = AutoModelForMaskedLM.from_pretrained(bert_path, output_hidden_states=True)

# Build the wrapper model.
wrapped_model = MyBertModel(model)

# Prepare example inputs for tracing.
text = "这是一条用于导出TorchScript的示例文本"
encoded = tokenizer(text, return_tensors="pt")
# NOTE(review): heuristic word2ph — 2 phones per character, 1 for the listed
# punctuation. Confirm it matches what the real text frontend produces.
word2ph = torch.tensor([2 if c not in ",。?!,.?" else 1 for c in text], dtype=torch.int)

# Pack the inputs as keyword arguments.
example_inputs = {
    "input_ids": encoded["input_ids"],
    "attention_mask": encoded["attention_mask"],
    "token_type_ids": encoded["token_type_ids"],
    "word2ph": word2ph
}

# Trace the model and save it.
# NOTE(review): `example_kwarg_inputs` requires a recent torch (>= 2.0) —
# confirm against the project's pinned version.
# NOTE(review): the save path is relative to the CWD and differs from the
# "GPT_SoVITS/pretrained_models" prefix used above — confirm intended.
traced = torch.jit.trace(wrapped_model, example_kwarg_inputs=example_inputs)
traced.save("pretrained_models/bert_script.pt")
print("✅ BERT TorchScript 模型导出完成!")
2 changes: 2 additions & 0 deletions config.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
bert_path = "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large"
pretrained_sovits_path = "GPT_SoVITS/pretrained_models/s2G488k.pth"
pretrained_gpt_path = "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt"
# pretrained_sovits_path = "GPT_SoVITS/pretrained_models/Aris_e16_s272.pth.pth"
# pretrained_gpt_path = "GPT_SoVITS/pretrained_models/Aris_e15.ckpt"

exp_root = "logs"
python_exec = sys.executable or "python"
Expand Down
Binary file added output.wav
Binary file not shown.
Binary file modified requirements.txt
Binary file not shown.
45 changes: 45 additions & 0 deletions test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import requests

# Local api_v2.py TTS endpoint.
url = "http://127.0.0.1:9880/tts"

# Request body matching the POST schema defined in api_v2.py.
payload = {
    "ref_audio_path": r"C:\Users\bdxly\Desktop\GPT-SoVITS\Arona_Academy_In_2.ogg.wav",
    "prompt_text": "様々な授業やイベントが準備されているので、ご希望のスケジュールを選んでください!",
    "prompt_lang": "ja",
    "text": "这是我的失误。我的选择,和因它发生的这一切。 直到最后,迎来了这样的结局,我才明白您是对的。 …我知道,事到如今再来说这些,挺厚脸皮的。但还是拜托您了。老师。 我想,您一定会忘记我说的这些话,不过…没关系。因为就算您什么都不记得了,在相同的情况下,应该还是会做那样的选择吧…… 所以重要的不是经历,是选择。 很多很多,只有您才能做出的选择。 我们以前聊过……关于负责人之人的话题吧。 我当时不懂……但是现在,我能理解了。 身为大人的责任与义务。以及在其延长线上的,您所做出的选择。 甚至还有,您做出选择时的那份心情。…… 所以,老师。 您是我唯一可以信任的大人,我相信您一定能找到,通往与这条扭曲的终点截然不同的……另一个结局的正确选项。所以,老师,请您一定要",
    "text_lang": "zh",
    "top_k": 5,
    "top_p": 1.0,
    "temperature": 1.0,
    "text_split_method": "cut0",
    "batch_size": 1,
    "batch_threshold": 0.75,
    "split_bucket": True,
    "speed_factor": 1.0,
    "fragment_interval": 0.3,
    "seed": -1,
    "media_type": "wav",
    "streaming_mode": False,
    "parallel_infer": True,
    "repetition_penalty": 1.35,
    "sample_steps": 32,
    "super_sampling": False
}

# Fire the POST request and persist the returned audio on success.
resp = requests.post(url, json=payload)

if resp.status_code != 200:
    print(f" 生成失败: {resp.status_code}, 返回信息: {resp.text}")
else:
    with open("output.wav", "wb") as f:
        f.write(resp.content)
    print(" 生成成功,保存为 output.wav")