diff --git a/gradio_tabs/dataset.py b/gradio_tabs/dataset.py index 21b1063b5..475bf98f2 100644 --- a/gradio_tabs/dataset.py +++ b/gradio_tabs/dataset.py @@ -3,6 +3,8 @@ from style_bert_vits2.logging import logger from style_bert_vits2.utils.subprocess import run_script_with_log +import argparse +import transcribe def do_slice( model_name: str, @@ -39,18 +41,34 @@ def do_slice( def do_transcribe( model_name, - whisper_model, + model, compute_type, language, initial_prompt, device, + device_indexes, use_hf_whisper, batch_size, num_beams, + no_repeat_ngram_size: int = 10, ): if model_name == "": return "Error: モデル名を入力してください。" + success, message = transcribe.run( + model_name, + model, + compute_type, + language, + initial_prompt, + device, + device_indexes, + use_hf_whisper, + batch_size, + num_beams, + no_repeat_ngram_size, + ) + ''' cmd = [ "transcribe.py", "--model_name", @@ -72,9 +90,19 @@ def do_transcribe( cmd.append("--use_hf_whisper") cmd.extend(["--batch_size", str(batch_size)]) success, message = run_script_with_log(cmd) + ''' if not success: return f"Error: {message}. エラーメッセージが空の場合、何も問題がない可能性があるので、書き起こしファイルをチェックして問題なければ無視してください。" +import torch + +# 使用可能なGPUのインデックスリスト取得 +def get_gpu_indexes(): + gpu_indexes = [] + if torch.cuda.is_available(): + for i in range(torch.cuda.device_count()): + gpu_indexes.append(str(i)) + return ','.join(gpu_indexes) how_to_md = """ Style-Bert-VITS2の学習用データセットを作成するためのツールです。以下の2つからなります。 @@ -192,6 +220,11 @@ def create_dataset_app() -> gr.Blocks: visible=False, ) device = gr.Radio(["cuda", "cpu"], label="デバイス", value="cuda") + device_indexes = gr.Textbox( + label="使用GPUインデックス", + value=get_gpu_indexes(), + info="使用するGPUインデックスをカンマ区切りで指定、例文(0,1,2)", + ) language = gr.Dropdown(["ja", "en", "zh"], value="ja", label="言語") initial_prompt = gr.Textbox( label="初期プロンプト", @@ -229,6 +262,7 @@ def create_dataset_app() -> gr.Blocks: language, initial_prompt, device, + device_indexes, use_hf_whisper, batch_size, num_beams, diff --git a/requirements.txt b/requirements.txt index 114e97202..ce89f119a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -28,3 +28,4 @@ tensorboard torch>=2.1 transformers umap-learn +portalocker diff --git a/transcribe.py b/transcribe.py index 210ce2c7e..2e35a49a3 100644 --- a/transcribe.py +++ b/transcribe.py @@ -12,6 +12,54 @@ from style_bert_vits2.logging import logger from style_bert_vits2.utils.stdout_wrapper import SAFE_STDOUT +from threading import Thread +import queue +import time +# 使用可能なGPUインデックスを入れるキュー +device_queue = queue.Queue() + +def transcribe_thread( + model, + device_index, + wav_file, + output_file, + model_name, + language_id, + initial_prompt, + language, + num_beams, + no_repeat_ngram_size +): + + text = transcribe_with_faster_whisper( + model=model, + audio_file=wav_file, + initial_prompt=initial_prompt, + language=language, + num_beams=num_beams, + no_repeat_ngram_size=no_repeat_ngram_size, + ) + + with open(output_file, "a", encoding="utf-8") as f: + if lock_file(f): + f.write(f"{wav_file.name}|{model_name}|{language_id}|{text}\n") + unlock_file(f) + + device_queue.put(device_index) + +import portalocker + +def lock_file(file_obj): + try: + # ファイルに排他ロックを設定する + portalocker.lock(file_obj, portalocker.LOCK_EX) + return True + except portalocker.LockException: + return False + +def unlock_file(file_obj): + # ファイルのロックを解除する + portalocker.unlock(file_obj) # faster-whisperは並列処理しても速度が向上しないので、単一モデルでループ処理する def transcribe_with_faster_whisper( @@ -103,6 +151,150 @@ def transcribe_files_with_hf_whisper( return results +def run( + model_name:str, + model:str="large-v3", + compute_type:str="bfloat16", + language:str="ja", + initial_prompt:str="こんにちは。元気、ですかー?ふふっ、私は……ちゃんと元気だよ!", + device:str="cuda", + device_indexes:str="0", + use_hf_whisper=True, + batch_size:int=16, + num_beams:int=1, + no_repeat_ngram_size:int=10, +): + + with open(os.path.join("configs", "paths.yml"), "r", encoding="utf-8") as f: + path_config: dict[str, str] = yaml.safe_load(f.read()) + dataset_root = Path(path_config["dataset_root"]) + + model_name = str(model_name) + + input_dir = dataset_root / model_name / "raw" + output_file = dataset_root / model_name / "esd.list" + initial_prompt: str = initial_prompt + initial_prompt = initial_prompt.strip('"') + language: str = language + device: str = device + # GPUインデックスリスト + device_indexes = [int(x) for x in device_indexes.split(',')] + compute_type: str = compute_type + batch_size: int = batch_size + num_beams: int = num_beams + no_repeat_ngram_size: int = no_repeat_ngram_size + + output_file.parent.mkdir(parents=True, exist_ok=True) + + wav_files = [f for f in input_dir.rglob("*.wav") if f.is_file()] + wav_files = sorted(wav_files, key=lambda x: x.name) + + if output_file.exists(): + logger.warning(f"{output_file} exists, backing up to {output_file}.bak") + backup_path = output_file.with_name(output_file.name + ".bak") + if backup_path.exists(): + logger.warning(f"{output_file}.bak exists, deleting...") + backup_path.unlink() + output_file.rename(backup_path) + + if language == "ja": + language_id = Languages.JP.value + elif language == "en": + language_id = Languages.EN.value + elif language == "zh": + language_id = Languages.ZH.value + else: + raise ValueError(f"{language} is not supported.") + + if not use_hf_whisper: + from faster_whisper import WhisperModel + + logger.info( + f"Loading faster-whisper model ({model}) with compute_type={compute_type}" + ) + + models = {} + + # 使用するGPUの数だけモデルを作成する。 + for device_index in device_indexes: + try: + model_object = WhisperModel(model, device=device, device_index=device_index, compute_type=compute_type) + except ValueError as e: + logger.warning(f"Failed to load model, so use `auto` compute_type: {e}") + model_object = WhisperModel(model, device=device, device_index=device_index) + models[device_index]=model_object + # 使用可能なモデルのキューを入れる + device_queue.put(device_index) + + # マルチスレッド開始 + threads = [] + for wav_file in tqdm(wav_files): + while True: + # 使用可能なモデルが無ければループする。 + if not device_queue.empty(): + device_index = device_queue.get() + thread = Thread(target=transcribe_thread, args=( + models[device_index], + device_index, + wav_file, + output_file, + model_name, + language_id, + initial_prompt, + language, + num_beams, + no_repeat_ngram_size + )) + thread.start() + threads.append(thread) + break + time.sleep(0.01) + + for thread in threads: + thread.join() + + # モデルの解放 + for device_index in device_indexes: + models[device_index]=None + + ''' + try: + model = WhisperModel(args.model, device=device, compute_type=compute_type) + except ValueError as e: + logger.warning(f"Failed to load model, so use `auto` compute_type: {e}") + model = WhisperModel(args.model, device=device) + for wav_file in tqdm(wav_files, file=SAFE_STDOUT): + text = transcribe_with_faster_whisper( + model=model, + audio_file=wav_file, + initial_prompt=initial_prompt, + language=language, + num_beams=num_beams, + no_repeat_ngram_size=no_repeat_ngram_size, + ) + with open(output_file, "a", encoding="utf-8") as f: + f.write(f"{wav_file.name}|{model_name}|{language_id}|{text}\n") + ''' + else: + model_id = f"openai/whisper-{model}" + logger.info(f"Loading HF Whisper model ({model_id})") + pbar = tqdm(total=len(wav_files), file=SAFE_STDOUT) + results = transcribe_files_with_hf_whisper( + audio_files=wav_files, + model_id=model_id, + initial_prompt=initial_prompt, + language=language, + batch_size=batch_size, + num_beams=num_beams, + no_repeat_ngram_size=no_repeat_ngram_size, + device=device, + pbar=pbar, + ) + with open(output_file, "w", encoding="utf-8") as f: + for wav_file, text in zip(wav_files, results): + f.write(f"{wav_file.name}|{model_name}|{language_id}|{text}\n") + + return True, "" if __name__ == "__main__": parser = argparse.ArgumentParser() @@ -124,6 +316,23 @@ def transcribe_files_with_hf_whisper( parser.add_argument("--no_repeat_ngram_size", type=int, default=10) args = parser.parse_args() + run( + args.model_name, + args.model, + args.compute_type, + args.language, + args.initial_prompt, + args.device, + args.device_indexes, + args.use_hf_whisper, + args.batch_size, + args.num_beams, + args.no_repeat_ngram_size, + ) + + sys.exit(0) + +''' with open(os.path.join("configs", "paths.yml"), "r", encoding="utf-8") as f: path_config: dict[str, str] = yaml.safe_load(f.read()) dataset_root = Path(path_config["dataset_root"]) @@ -205,3 +414,4 @@ def transcribe_files_with_hf_whisper( f.write(f"{wav_file.name}|{model_name}|{language_id}|{text}\n") sys.exit(0) +'''