11 changes: 9 additions & 2 deletions cosyvoice/cli/model.py
@@ -47,6 +47,7 @@ def __init__(self,
         # dict used to store session related variable
         self.tts_speech_token_dict = {}
         self.llm_end_dict = {}
+        self.flow_cache_dict = {}
         self.mel_overlap_dict = {}
         self.hift_cache_dict = {}

@@ -92,13 +93,18 @@ def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
         self.llm_end_dict[uuid] = True

     def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=False):
-        tts_mel = self.flow.inference(token=token.to(self.device),
+        tts_mel, flow_cache = self.flow.inference(token=token.to(self.device),
                                       token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
                                       prompt_token=prompt_token.to(self.device),
                                       prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
                                       prompt_feat=prompt_feat.to(self.device),
                                       prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
-                                      embedding=embedding.to(self.device))
+                                      embedding=embedding.to(self.device),
+                                      required_cache_size=self.mel_overlap_len,
+                                      flow_cache=self.flow_cache_dict[uuid])
+        self.flow_cache_dict[uuid] = flow_cache
+        tts_mel = tts_mel.float()
+
         # mel overlap fade in out
         if self.mel_overlap_dict[uuid] is not None:
             tts_mel = fade_in_out(tts_mel, self.mel_overlap_dict[uuid], self.mel_window)
@@ -128,6 +134,7 @@ def inference(self, text, flow_embedding, llm_embedding=torch.zeros(0, 192),
         this_uuid = str(uuid.uuid1())
         with self.lock:
             self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
+            self.flow_cache_dict[this_uuid] = None
             self.mel_overlap_dict[this_uuid], self.hift_cache_dict[this_uuid] = None, None
         p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
         p.start()
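
The session bookkeeping above pairs with a driver loop elsewhere in the class: each chunk threads the per-session flow cache through token2wav, which hands it to flow.inference and stores the updated cache for the next chunk. A minimal sketch of that lifecycle follows (hypothetical driver, not part of this PR: play, speech_token_chunks, the prompt tensors, and spk_embedding are placeholders, and model stands for a CosyVoiceModel instance):

    import uuid

    # Hypothetical streaming driver: initialize per-session state exactly as
    # inference() does above, then synthesize chunk by chunk.
    session = str(uuid.uuid1())
    with model.lock:
        model.tts_speech_token_dict[session], model.llm_end_dict[session] = [], False
        model.flow_cache_dict[session] = None          # flow cache starts empty
        model.mel_overlap_dict[session], model.hift_cache_dict[session] = None, None

    for chunk in speech_token_chunks:                  # (1, T) speech-token tensors
        speech = model.token2wav(token=chunk,
                                 prompt_token=prompt_token,
                                 prompt_feat=prompt_feat,
                                 embedding=spk_embedding,
                                 uuid=session,
                                 finalize=False)
        play(speech)                                   # consume audio as it arrives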
12 changes: 8 additions & 4 deletions cosyvoice/flow/flow.py
@@ -109,7 +109,9 @@ def inference(self,
                   prompt_token_len,
                   prompt_feat,
                   prompt_feat_len,
-                  embedding):
+                  embedding,
+                  required_cache_size=0,
+                  flow_cache=None):
         assert token.shape[0] == 1
         # xvec projection
         embedding = F.normalize(embedding, dim=1)
@@ -134,13 +136,15 @@
         # mask = (~make_pad_mask(feat_len)).to(h)
         mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
-        feat = self.decoder(
+        feat, flow_cache = self.decoder(
             mu=h.transpose(1, 2).contiguous(),
             mask=mask.unsqueeze(1),
             spks=embedding,
             cond=conds,
-            n_timesteps=10
+            n_timesteps=10,
+            required_cache_size=required_cache_size,
+            flow_cache=flow_cache
         )
         feat = feat[:, :, mel_len1:]
         assert feat.shape[2] == mel_len2
-        return feat
+        return feat, flow_cache
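
Because inference now returns a (mel, cache) tuple rather than a bare tensor, every call site must unpack two values and feed the cache back in on the next chunk. A sketch of the updated calling convention, with placeholder tensor names (required_cache_size is set to the caller's mel overlap length, matching model.py above):

    # Chunk 1: no cache yet; ask the decoder to retain `overlap` trailing frames.
    mel1, cache = flow.inference(token=tokens1, token_len=token_len1,
                                 prompt_token=prompt_token, prompt_token_len=prompt_token_len,
                                 prompt_feat=prompt_feat, prompt_feat_len=prompt_feat_len,
                                 embedding=embedding,
                                 required_cache_size=overlap, flow_cache=None)
    # Chunk 2: pass the cache back so the overlapping frames are regenerated
    # from the same noise and conditioning as chunk 1's tail.
    mel2, cache = flow.inference(token=tokens2, token_len=token_len2,
                                 prompt_token=prompt_token, prompt_token_len=prompt_token_len,
                                 prompt_feat=prompt_feat, prompt_feat_len=prompt_feat_len,
                                 embedding=embedding,
                                 required_cache_size=overlap, flow_cache=cache)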
21 changes: 18 additions & 3 deletions cosyvoice/flow/flow_matching.py
@@ -32,7 +32,7 @@ def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator:
         self.estimator = estimator

     @torch.inference_mode()
-    def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
+    def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, required_cache_size=0, flow_cache=None):
         """Forward diffusion

         Args:
@@ -50,11 +50,26 @@ def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
             sample: generated mel-spectrogram
                 shape: (batch_size, n_feats, mel_timesteps)
         """
-        z = torch.randn_like(mu) * temperature
+
+        if flow_cache is not None:
+            z_cache = flow_cache[0]
+            mu_cache = flow_cache[1]
+            z = torch.randn((mu.size(0), mu.size(1), mu.size(2) - z_cache.size(2)), dtype=mu.dtype, device=mu.device) * temperature
+            z = torch.cat((z_cache, z), dim=2)  # [B, 80, T]
+            mu = torch.cat((mu_cache, mu[..., mu_cache.size(2):]), dim=2)  # [B, 80, T]
+        else:
+            z = torch.randn_like(mu) * temperature
+
+        next_cache_start = max(z.size(2) - required_cache_size, 0)
+        flow_cache = [
+            z[..., next_cache_start:],
+            mu[..., next_cache_start:]
+        ]
+
         t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
         if self.t_scheduler == 'cosine':
             t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
-        return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond)
+        return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), flow_cache

     def solve_euler(self, x, t_span, mu, mask, spks, cond):
         """
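
Why cache z and mu at all: consecutive streaming chunks overlap by required_cache_size mel frames, and the cross-fade in token2wav only works if the ODE solve produces near-identical frames in that overlap. Reusing the cached tail of the noise z and the conditioning mu guarantees the solver sees identical inputs for those frames. A self-contained sketch of just this bookkeeping (the step function is illustrative, not repo code; it mirrors forward() above minus the ODE solve):

    import torch

    def step(mu, flow_cache, required_cache_size, temperature=1.0):
        # Mirror of the caching logic in ConditionalCFM.forward above.
        if flow_cache is not None:
            z_cache, mu_cache = flow_cache
            # Fresh noise only for the frames beyond the cached tail ...
            z = torch.randn(mu.size(0), mu.size(1), mu.size(2) - z_cache.size(2)) * temperature
            # ... cached noise/conditioning overwrite the overlap region.
            z = torch.cat((z_cache, z), dim=2)
            mu = torch.cat((mu_cache, mu[..., mu_cache.size(2):]), dim=2)
        else:
            z = torch.randn_like(mu) * temperature
        next_cache_start = max(z.size(2) - required_cache_size, 0)
        return z, mu, [z[..., next_cache_start:], mu[..., next_cache_start:]]

    # Two consecutive "chunks"; the second's first 34 frames overlap the first's tail.
    mu1 = torch.randn(1, 80, 50)
    z1, mu1, cache = step(mu1, None, required_cache_size=34)
    mu2 = torch.randn(1, 80, 90)
    z2, mu2, cache = step(mu2, cache, required_cache_size=34)
    assert torch.equal(z2[..., :34], z1[..., -34:])    # overlap noise identical
    assert torch.equal(mu2[..., :34], mu1[..., -34:])  # overlap conditioning identical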