@@ -1,18 +1,10 @@
 # modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/t2s_lightning_module.py
 # reference: https://github.com/lifeiteng/vall-e
-import os
-import sys
 
-now_dir = os.getcwd()
-sys.path.append(now_dir)
-from typing import Dict
 
-import torch
 from pytorch_lightning import LightningModule
 
-from AR.models.t2s_model_onnx import Text2SemanticDecoder
-from AR.modules.lr_schedulers import WarmupCosineLRSchedule
-from AR.modules.optim import ScaledAdam
+from .t2s_model_onnx import Text2SemanticDecoder
 
 
 class Text2SemanticLightningModule(LightningModule):
@@ -21,90 +13,3 @@ def __init__(self, config, output_dir, is_train=True):
         self.config = config
         self.top_k = 3
         self.model = Text2SemanticDecoder(config=config, top_k=self.top_k)
-        pretrained_s1 = config.get("pretrained_s1")
-        if pretrained_s1 and is_train:
-            # print(self.load_state_dict(torch.load(pretrained_s1,map_location="cpu")["state_dict"]))
-            print(
-                self.load_state_dict(
-                    torch.load(
-                        pretrained_s1,
-                        map_location="cpu",
-                    )["weight"],
-                ),
-            )
-        if is_train:
-            self.automatic_optimization = False
-            self.save_hyperparameters()
-            self.eval_dir = output_dir / "eval"
-            self.eval_dir.mkdir(parents=True, exist_ok=True)
-
-    def training_step(self, batch: Dict, batch_idx: int):
-        opt = self.optimizers()
-        scheduler = self.lr_schedulers()
-        loss, acc = self.model.forward(
-            batch["phoneme_ids"],
-            batch["phoneme_ids_len"],
-            batch["semantic_ids"],
-            batch["semantic_ids_len"],
-            batch["bert_feature"],
-        )
-        self.manual_backward(loss)
-        if batch_idx > 0 and batch_idx % 4 == 0:
-            opt.step()
-            opt.zero_grad()
-            scheduler.step()
-
-        self.log(
-            "total_loss",
-            loss,
-            on_step=True,
-            on_epoch=True,
-            prog_bar=True,
-            sync_dist=True,
-        )
-        self.log(
-            "lr",
-            scheduler.get_last_lr()[0],
-            on_epoch=True,
-            prog_bar=True,
-            sync_dist=True,
-        )
-        self.log(
-            f"top_{self.top_k}_acc",
-            acc,
-            on_step=True,
-            on_epoch=True,
-            prog_bar=True,
-            sync_dist=True,
-        )
-
-    def validation_step(self, batch: Dict, batch_idx: int):
-        return
-
-    def configure_optimizers(self):
-        model_parameters = self.model.parameters()
-        parameters_names = []
-        parameters_names.append([name_param_pair[0] for name_param_pair in self.model.named_parameters()])
-        lm_opt = ScaledAdam(
-            model_parameters,
-            lr=0.01,
-            betas=(0.9, 0.95),
-            clipping_scale=2.0,
-            parameters_names=parameters_names,
-            show_dominant_parameters=False,
-            clipping_update_period=1000,
-        )
-
-        return {
-            "optimizer": lm_opt,
-            "lr_scheduler": {
-                "scheduler": WarmupCosineLRSchedule(
-                    lm_opt,
-                    init_lr=self.config["optimizer"]["lr_init"],
-                    peak_lr=self.config["optimizer"]["lr"],
-                    end_lr=self.config["optimizer"]["lr_end"],
-                    warmup_steps=self.config["optimizer"]["warmup_steps"],
-                    total_steps=self.config["optimizer"]["decay_steps"],
-                )
-            },
-        }
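
For reference, a minimal usage sketch of the trimmed, inference-only module (not part of this commit). Only the constructor signature and the "weight" checkpoint key come from the diff above; the module import path, config loading, and file names are assumptions for illustration.

# Hypothetical sketch: load the slimmed-down LightningModule for export/inference.
# Paths and the import path are assumptions, not taken from this diff.
import torch
import yaml

from AR.models.t2s_lightning_module_onnx import Text2SemanticLightningModule  # assumed module name

config = yaml.safe_load(open("configs/s1.yaml", "r", encoding="utf-8"))  # hypothetical s1 config file
ckpt = torch.load("s1_pretrained.ckpt", map_location="cpu")              # hypothetical checkpoint file

# is_train=False: the training-only setup removed in this commit is never needed here.
module = Text2SemanticLightningModule(config, output_dir=None, is_train=False)
module.load_state_dict(ckpt["weight"])  # same "weight" key used by the removed loading code
module.eval()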