 from torchmetrics.classification import MulticlassAccuracy
 from tqdm import tqdm

-from AR.models.utils import (
+from GPT_SoVITS.AR.models.utils import (
     dpo_loss,
     get_batch_logps,
     make_pad_mask,
     sample,
     topk_sampling,
 )
-from AR.modules.embedding import SinePositionalEmbedding, TokenEmbedding
-from AR.modules.transformer import LayerNorm, TransformerEncoder, TransformerEncoderLayer
+from GPT_SoVITS.AR.modules.embedding import SinePositionalEmbedding, TokenEmbedding
+from GPT_SoVITS.AR.modules.transformer import LayerNorm, TransformerEncoder, TransformerEncoderLayer

 default_config = {
     "embedding_dim": 512,
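Side note (illustrative, not part of the patch): the hunk above switches the intra-package imports to fully qualified GPT_SoVITS.AR.* paths. That form assumes the repository root, i.e. the directory containing the GPT_SoVITS package, is importable (installed or on sys.path); the path handling below is a hypothetical sketch of that assumption, not the project's actual launcher code.

import sys
from pathlib import Path

# Hypothetical layout: <repo_root>/GPT_SoVITS/AR/models/<this_module>.py
repo_root = Path(__file__).resolve().parents[3]
if str(repo_root) not in sys.path:
    sys.path.insert(0, str(repo_root))

# With the repo root importable, the package-qualified import resolves regardless
# of the current working directory:
from GPT_SoVITS.AR.models.utils import make_pad_mask  # noqa: E402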
@@ -420,7 +420,7 @@ def forward(self, x, x_lens, y, y_lens, bert_feature):
             mask=xy_attn_mask,
         )
         x_len = x_lens.max()
-        logits = self.ar_predict_layer(xy_dec[:, x_len - 1:])
+        logits = self.ar_predict_layer(xy_dec[:, x_len - 1 :])

         ###### DPO #############
         reject_xy_pos, reject_xy_attn_mask, reject_targets = self.make_input_data(
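Side note on the xy_dec[:, x_len - 1 :] slice above (toy shapes, not repository code): the decoder output spans the concatenated text (x) and semantic (y) positions, and the slice keeps the last text position plus every y position, i.e. 1 + y_len steps, which is how autoregressive predictions are typically aligned with the y targets.

import torch

x_len, y_len, dim = 4, 3, 2                   # toy text length, semantic length, width
xy_dec = torch.randn(1, x_len + y_len, dim)   # stand-in for the transformer output

tail = xy_dec[:, x_len - 1 :]                 # last text step + every y step
print(tail.shape)                             # torch.Size([1, 4, 2]), i.e. 1 + y_len positions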
@@ -432,7 +432,7 @@ def forward(self, x, x_lens, y, y_lens, bert_feature):
             mask=reject_xy_attn_mask,
         )
         x_len = x_lens.max()
-        reject_logits = self.ar_predict_layer(reject_xy_dec[:, x_len - 1:])
+        reject_logits = self.ar_predict_layer(reject_xy_dec[:, x_len - 1 :])

         # loss
         # from feiteng: the longer the duration each time, the larger the gradient update should be, so use sum
@@ -502,7 +502,7 @@ def forward_old(self, x, x_lens, y, y_lens, bert_feature):
             (xy_pos, None),
             mask=xy_attn_mask,
         )
-        logits = self.ar_predict_layer(xy_dec[:, x_len - 1:]).permute(0, 2, 1)
+        logits = self.ar_predict_layer(xy_dec[:, x_len - 1 :]).permute(0, 2, 1)
         # loss
         # from feiteng: the longer the duration each time, the larger the gradient update should be, so use sum
         loss = F.cross_entropy(logits, targets, reduction="sum")
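Side note on reduction="sum" (toy tensors, not repository code): the feiteng comment argues that a longer target sequence should produce a proportionally larger gradient step, which is what summing rather than averaging the per-token cross-entropy gives.

import torch
import torch.nn.functional as F

vocab, t_short, t_long = 1025, 10, 100        # made-up vocabulary size and sequence lengths
logits_short = torch.randn(1, vocab, t_short)
logits_long = torch.randn(1, vocab, t_long)
tgt_short = torch.randint(0, vocab, (1, t_short))
tgt_long = torch.randint(0, vocab, (1, t_long))

# "mean" normalises by token count, so short and long sequences weigh the same;
# "sum" lets the longer sequence contribute roughly 10x more loss (and gradient) here.
for reduction in ("mean", "sum"):
    loss_short = F.cross_entropy(logits_short, tgt_short, reduction=reduction)
    loss_long = F.cross_entropy(logits_long, tgt_long, reduction=reduction)
    print(reduction, round(loss_short.item(), 2), round(loss_long.item(), 2))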
@@ -724,8 +724,8 @@ def infer_panel_batch_infer(
             l1 = samples[:, 0] == self.EOS
             l2 = tokens == self.EOS
             l = l1.logical_or(l2)
-            removed_idx_of_batch_for_y = torch.where(l == True)[0].tolist()
-            reserved_idx_of_batch_for_y = torch.where(l == False)[0]
+            removed_idx_of_batch_for_y = torch.where(l)[0].tolist()
+            reserved_idx_of_batch_for_y = torch.where(~l)[0]
             # batch_indexs = torch.tensor(batch_idx_map, device=y.device)[removed_idx_of_batch_for_y]
             for i in removed_idx_of_batch_for_y:
                 batch_index = batch_idx_map[i]
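Side note on the EOS bookkeeping above (toy tensors, hypothetical EOS id): a boolean mask over the batch marks sequences that just emitted EOS, and torch.where on the mask and on its negation yields the batch indices to drop and to keep. An identity check such as "l is True" is always False for a tensor, so it is the elementwise mask itself that must be indexed.

import torch

EOS = 1024                                            # assumed EOS token id, illustration only
samples = torch.tensor([[1024], [17], [1024], [99]])  # last sampled token per batch entry
tokens = torch.tensor([3, 1024, 5, 7])                # toy parallel token vector

l1 = samples[:, 0] == EOS
l2 = tokens == EOS
l = l1.logical_or(l2)                                 # True where a sequence has finished

removed_idx = torch.where(l)[0].tolist()              # finished entries -> [0, 1, 2]
reserved_idx = torch.where(~l)[0]                     # still decoding   -> tensor([3])
print(removed_idx, reserved_idx)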