
Commit 195bd8b

add training scripts to train sd
1 parent ba4cedf commit 195bd8b

File tree

3 files changed: +456 -8 lines changed


mugen/trainer.py

Lines changed: 276 additions & 0 deletions
@@ -0,0 +1,276 @@
from __future__ import annotations
from typing import Iterable, TYPE_CHECKING
import os
import math
from tqdm import tqdm

import torch
import accelerate
import diffusers

from torch.utils.data import DataLoader, Dataset
from torch.nn import Parameter
from accelerate.logging import get_logger
from diffusers.optimization import get_scheduler

from mugen.utils.trainer_utils import set_seed, get_last_checkpoint, prune_checkpoints

if TYPE_CHECKING:
    from mugen import TrainingArguments
    from mugen.trainingmodules import TrainingModule
    from torch.optim import Optimizer
    from torch.optim.lr_scheduler import _LRScheduler as LRScheduler


logger = get_logger(__name__, log_level="INFO")


class Trainer:
    def __init__(
        self,
        project_name: str,
        training_module: TrainingModule,
        training_args: TrainingArguments,
        train_dataset: Dataset,
        eval_dataset: Dataset,
    ):
        self.training_args = training_args
        self.global_step = 0

        set_seed(self.training_args.seed)

        self.accelerator = accelerate.Accelerator(
            gradient_accumulation_steps=training_args.gradient_accumulation_steps,
            mixed_precision=training_args.mixed_precision,
            log_with=training_args.logger,
            cpu=training_args.use_cpu,
            deepspeed_plugin=training_args.get_deepspeed_plugin(),
            fsdp_plugin=training_args.get_fsdp_plugin(),
            project_config=training_args.get_project_configuration(),
        )

        if self.accelerator.is_local_main_process:
            diffusers.utils.logging.set_verbosity_info()
        else:
            diffusers.utils.logging.set_verbosity_error()

        self.accelerator.register_save_state_pre_hook(training_module.save_model_hook)
        self.accelerator.register_load_state_pre_hook(training_module.load_model_hook)

        self.training_module = training_module
        self.training_module.register_trainer(self)

        self.train_dataloader = self.get_train_dataloader(train_dataset)
        self.val_dataloader = self.get_eval_dataloader(eval_dataset)

        self.optimizers = [
            self.create_optimizer(params)
            for params in self.training_module.get_optim_params()
        ]

        num_training_steps = len(self.train_dataloader) * self.training_args.num_epochs
        self.schedulers = [
            self.create_scheduler(
                opt,
                num_training_steps=num_training_steps,
                num_warmup_steps=self.training_args.get_warmup_steps(
                    num_training_steps
                ),
            )
            for opt in self.optimizers
        ]

        # Prepare with Accelerator
        self.training_module = self.accelerator.prepare_model(self.training_module)
        for i in range(len(self.optimizers)):
            self.optimizers[i] = self.accelerator.prepare_optimizer(self.optimizers[i])
        for i in range(len(self.schedulers)):
            self.schedulers[i] = self.accelerator.prepare_scheduler(self.schedulers[i])
        self.train_dataloader = self.accelerator.prepare_data_loader(
            self.train_dataloader
        )
        self.val_dataloader = self.accelerator.prepare_data_loader(self.val_dataloader)

        if self.accelerator.is_main_process:
            self.accelerator.init_trackers(
                project_name,
                init_kwargs={
                    self.training_args.logger: self.training_args.tracker_init_kwargs
                },
            )

    def start(self):
        total_batch_size = (
            self.training_args.train_batch_size
            * self.accelerator.num_processes
            * self.training_args.gradient_accumulation_steps
        )
        num_update_steps_per_epoch = math.ceil(
            len(self.train_dataloader) / self.training_args.gradient_accumulation_steps
        )
        max_train_steps = self.training_args.num_epochs * num_update_steps_per_epoch

        logger.info("***** Running training *****")
        logger.info(f" Num examples = {len(self.train_dataloader.dataset)}")
        logger.info(f" Num Epochs = {self.training_args.num_epochs}")
        logger.info(
            f" Instantaneous batch size per device = {self.training_args.train_batch_size}"
        )
        logger.info(
            f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}"
        )
        logger.info(
            f" Gradient Accumulation steps = {self.training_args.gradient_accumulation_steps}"
        )
        logger.info(f" Total optimization steps = {max_train_steps}")

        first_epoch = 0

        if self.training_args.resume_from_checkpoint:
            if self.training_args.resume_from_checkpoint == "latest":
                path = get_last_checkpoint(self.training_args.output_dir)
            else:
                path = self.training_args.resume_from_checkpoint

            if path is None or not os.path.exists(path):
                self.accelerator.print(
                    f"Checkpoint not found at {path}. Starting a new training run."
                )
                self.training_args.resume_from_checkpoint = None
            else:
                self.accelerator.print(f"Loading checkpoint from {path}")
                self.accelerator.load_state(path)

                self.global_step = int(os.path.basename(path).split("-")[-1])

                resume_global_step = (
                    self.global_step * self.training_args.gradient_accumulation_steps
                )
                first_epoch = self.global_step // num_update_steps_per_epoch
                resume_step = resume_global_step % (
                    num_update_steps_per_epoch
                    * self.training_args.gradient_accumulation_steps
                )

        # Train!
        self.training_module.on_start()
        for epoch in range(first_epoch, self.training_args.num_epochs):
            with tqdm(
                total=num_update_steps_per_epoch,
                disable=not self.accelerator.is_local_main_process,
            ) as progress_bar:
                self.training_module.register_progress_bar(progress_bar)
                progress_bar.set_description(f"Epoch {epoch}")

                self.training_module.train()
                self.training_module.on_train_epoch_start()
                for step, batch in enumerate(self.train_dataloader):
                    # Skip steps until we reach the resumed step
                    if (
                        self.training_args.resume_from_checkpoint
                        and epoch == first_epoch
                        and step < resume_step
                    ):
                        if step % self.training_args.gradient_accumulation_steps == 0:
                            progress_bar.update(1)
                        continue

                    self.training_module.on_train_batch_start()

                    with self.accelerator.accumulate(self.training_module):
                        self.training_module.training_step(batch, self.optimizers, step)
                        for scheduler in self.schedulers:
                            scheduler.step()

                    # Only count an optimization step once gradients have synced
                    # (i.e. after a full gradient-accumulation cycle).
                    if self.accelerator.sync_gradients:
                        self.training_module.on_train_batch_end()
                        progress_bar.update(1)

                        self.global_step += 1

                        # Periodically checkpoint; prune old checkpoints first so the
                        # total stays within save_total_limit.
                        if self.global_step % self.training_args.save_steps == 0:
                            if self.accelerator.is_main_process:
                                prune_checkpoints(
                                    self.training_args.output_dir,
                                    self.training_args.save_total_limit - 1,
                                )
                                save_path = os.path.join(
                                    self.training_args.output_dir,
                                    f"checkpoint-{self.global_step}",
                                )
                                self.accelerator.save_state(save_path)
                                logger.info(f"Saved state to {save_path}")

                        # Run evaluation at the configured cadence.
                        if (
                            self.global_step
                            % self.training_args.get_eval_steps(max_train_steps)
                            == 0
                        ):
                            self._eval_loop()

                if self.accelerator.is_main_process:
                    self.training_module.on_train_epoch_end()

        self.accelerator.wait_for_everyone()
        self.accelerator.end_training()

    def _eval_loop(self):
        with tqdm(
            total=len(self.val_dataloader),
            disable=not self.accelerator.is_local_main_process,
        ) as progress_bar:
            progress_bar.set_description("Evaluating...")

            self.training_module.eval()
            with torch.inference_mode():
                self.training_module.on_validation_epoch_start()
                for step, batch in enumerate(self.val_dataloader):
                    self.training_module.validation_step(batch, step)
                    progress_bar.update(1)

                if self.accelerator.is_main_process:
                    self.training_module.on_validation_epoch_end()

    def evaluate(self):
        self._eval_loop()

    def get_tracker(self, unwrap: bool = False):
        return self.accelerator.get_tracker(self.training_args.logger, unwrap)

    def create_optimizer(self, parameters: Iterable[Parameter]):
        return torch.optim.AdamW(
            parameters,
            lr=self.training_args.learning_rate,
            betas=(self.training_args.adam_beta1, self.training_args.adam_beta2),
            eps=self.training_args.adam_epsilon,
            weight_decay=self.training_args.adam_weight_decay,
        )

    def create_scheduler(
        self, optimizer: Optimizer, num_training_steps: int, num_warmup_steps: int
    ) -> LRScheduler:
        return get_scheduler(
            self.training_args.lr_scheduler_type,
            optimizer,
            num_warmup_steps=num_warmup_steps,
            num_training_steps=num_training_steps,
        )

    def get_train_dataloader(self, dataset: Dataset):
        if self.training_args.data_seed is not None:
            # torch.Generator.seed() takes no argument; manual_seed() is the call
            # that actually applies data_seed to the shuffling generator.
            generator = torch.Generator().manual_seed(self.training_args.data_seed)
        else:
            generator = None

        return DataLoader(
            dataset,
            batch_size=self.training_args.train_batch_size,
            num_workers=self.training_args.data_loader_num_workers,
            generator=generator,
            shuffle=True,
        )

    def get_eval_dataloader(self, dataset: Dataset):
        return DataLoader(
            dataset,
            batch_size=self.training_args.eval_batch_size,
            num_workers=self.training_args.data_loader_num_workers,
            shuffle=False,
        )
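
For orientation, here is a minimal sketch of how this Trainer might be driven end to end. It is not part of the commit: the keyword-style TrainingArguments construction, the placeholder field values, and the attribute used to reach the accelerator from inside the module are assumptions for illustration, and ToyModule relies on the TrainingModule base class providing no-op defaults for the remaining hooks (save_model_hook, on_start, and so on). Only the hook and attribute names that appear in the diff above are taken from the source.

    # Hypothetical usage sketch, NOT part of the commit.
    import torch
    from torch.utils.data import TensorDataset

    from mugen import TrainingArguments                # constructor signature assumed
    from mugen.trainingmodules import TrainingModule
    from mugen.trainer import Trainer


    class ToyModule(TrainingModule):
        """Illustrative module; a real one would wrap a Stable Diffusion pipeline."""

        def __init__(self):
            super().__init__()
            self.net = torch.nn.Linear(4, 1)

        def get_optim_params(self):
            # One parameter group -> the Trainer builds one optimizer/scheduler pair.
            return [self.net.parameters()]

        def training_step(self, batch, optimizers, step):
            (x,) = batch
            loss = self.net(x).pow(2).mean()
            self.trainer.accelerator.backward(loss)    # trainer attribute name assumed
            optimizers[0].step()
            optimizers[0].zero_grad()

        def validation_step(self, batch, step):
            pass


    args = TrainingArguments(                          # field values are placeholders
        output_dir="outputs/toy-run",
        num_epochs=1,
        train_batch_size=2,
        eval_batch_size=2,
    )
    trainer = Trainer(
        project_name="mugen-example",
        training_module=ToyModule(),
        training_args=args,
        train_dataset=TensorDataset(torch.randn(8, 4)),
        eval_dataset=TensorDataset(torch.randn(2, 4)),
    )
    trainer.start()                                    # runs training plus periodic eval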
