From c8f0981a01bd68d8e526edf4155a40fdea784dfe Mon Sep 17 00:00:00 2001
From: HuanLinOTO
Date: Sat, 9 Sep 2023 13:04:04 +0800
Subject: [PATCH 1/8] feat: stop training gracefully via a stop.txt sentinel

---
 .gitignore      |  4 ++++
 README.md       |  2 ++
 README_zh_CN.md |  2 ++
 train.py        | 13 +++++++++----
 4 files changed, 17 insertions(+), 4 deletions(-)

diff --git a/.gitignore b/.gitignore
index 269f8828..07f692e4 100644
--- a/.gitignore
+++ b/.gitignore
@@ -165,3 +165,7 @@ pretrain/
 .vscode/launch.json
 
 trained/**/
+
+configs/
+filelists/
+filelists/val.txt
diff --git a/README.md b/README.md
index d9e2f793..7ee558c5 100644
--- a/README.md
+++ b/README.md
@@ -347,6 +347,8 @@ After completing the above steps, the dataset directory will contain the preproc
 python train.py -c configs/config.json -m 44k
 ```
 
+You can pause training and save the model at any time: create a `stop.txt` file in the working directory, and training will stop and save a checkpoint once the current step finishes.
+
 ### Diffusion Model (optional)
 
 If the shallow diffusion function is needed, the diffusion model needs to be trained. The diffusion model training method is as follows:
diff --git a/README_zh_CN.md b/README_zh_CN.md
index 71c3a60a..79522b5f 100644
--- a/README_zh_CN.md
+++ b/README_zh_CN.md
@@ -350,6 +350,8 @@ python preprocess_hubert_f0.py --f0_predictor dio --use_diff --num_processes 8
 python train.py -c configs/config.json -m 44k
 ```
 
+这里提供了一个可以随时暂停训练并保存模型的方法:在运行目录下新建一个 `stop.txt` 即可,训练会在下一个 step 结束后停止并保存模型。
+
 ### 扩散模型(可选)
 
 尚若需要浅扩散功能,需要训练扩散模型,扩散模型训练方法为:
diff --git a/train.py b/train.py
index 8487f177..d523700b 100644
--- a/train.py
+++ b/train.py
@@ -146,7 +146,7 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
     global global_step
 
     net_g.train()
-    net_d.train()
+    net_d.train()
     for batch_idx, items in enumerate(train_loader):
         c, f0, spec, y, spk, lengths, uv,volume = items
         g = spk.cuda(rank, non_blocking=True)
@@ -252,8 +252,10 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
                     images=image_dict,
                     scalars=scalar_dict
                 )
-
-            if global_step % hps.train.eval_interval == 0:
+            # 达到保存步数或者 stop 文件存在
+            if global_step % hps.train.eval_interval == 0 or os.path.exists("stop.txt"):
+                if os.path.exists("stop.txt"):
+                    logger.info("stop.txt found, stop training")
                 evaluate(hps, net_g, eval_loader, writer_eval)
                 utils.save_checkpoint(net_g, optim_g, hps.train.learning_rate, epoch,
                                       os.path.join(hps.model_dir, "G_{}.pth".format(global_step)))
@@ -262,7 +264,10 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
                 keep_ckpts = getattr(hps.train, 'keep_ckpts', 0)
                 if keep_ckpts > 0:
                     utils.clean_checkpoints(path_to_models=hps.model_dir, n_ckpts_to_keep=keep_ckpts, sort_by_time=True)
-
+                if os.path.exists("stop.txt"):
+                    logger.info("good bye!")
+                    os.remove("stop.txt")
+                    os._exit(0)
         global_step += 1
 
     if rank == 0:
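The stop-file pattern in patch 1 generalizes to any long-running training loop: poll for a sentinel file at checkpoint boundaries, save, then consume the sentinel so the next run starts cleanly. A minimal sketch of the same idea, with a hypothetical `save_checkpoint` standing in for this repository's `utils.save_checkpoint`:

```python
import os

STOP_FILE = "stop.txt"  # sentinel checked once per step, as in the patch above

def save_checkpoint(step):
    # stand-in for utils.save_checkpoint(net_g, optim_g, ...) in train.py
    print(f"saved checkpoint at step {step}")

def train_loop(total_steps, eval_interval=800):
    for step in range(total_steps):
        # ... forward/backward pass and optimizer step would run here ...
        stop_requested = os.path.exists(STOP_FILE)
        if step % eval_interval == 0 or stop_requested:
            save_checkpoint(step)
            if stop_requested:
                os.remove(STOP_FILE)  # consume the sentinel so the next run is clean
                return  # exit only after the current step, never mid-step

train_loop(total_steps=10_000)
```

Checking the sentinel only at save boundaries is the key design choice: the process always exits with a checkpoint on disk, which `os._exit(0)` alone would not guarantee.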
From d1f07bc318d56327e4c0dae46ea42c1402386d1b Mon Sep 17 00:00:00 2001
From: HuanLinOTO
Date: Wed, 4 Oct 2023 17:36:24 +0800
Subject: [PATCH 2/8] feat(infer): auto-complete model path and input defaults

---
 README_zh_CN.md   | 12 ++++++------
 inference_main.py | 27 +++++++++++++++++++++++----
 2 files changed, 29 insertions(+), 10 deletions(-)

diff --git a/README_zh_CN.md b/README_zh_CN.md
index 79522b5f..d66b92ac 100644
--- a/README_zh_CN.md
+++ b/README_zh_CN.md
@@ -1,7 +1,7 @@
 <div align="center">
 <img alt="LOGO
 
-# SoftVC VITS Singing Voice Conversion
+# SoftVC VITS Singing Voice Conversion (HuanLin Ver)
 
 [**English**](./README.md) | [**中文简体**](./README_zh_CN.md)
 
@@ -372,14 +372,14 @@ python inference_main.py -m "logs/44k/G_30400.pth" -c "configs/config.json" -n "
 ```
 
 必填项部分:
-+ `-m` | `--model_path`:模型路径
-+ `-c` | `--config_path`:配置文件路径
-+ `-n` | `--clean_names`:wav 文件名列表,放在 raw 文件夹下
++ `-m` | `--model_path`:模型路径(接受参数 `{lastest}`,示例:`"logs/44k/{lastest}"`,将自动选取最新保存点)
++ `-c` | `--config_path`:配置文件路径,缺省时使用 `logs/44k/config.json`
 + `-t` | `--trans`:音高调整,支持正负(半音)
-+ `-s` | `--spk_list`:合成目标说话人名称
-+ `-cl` | `--clip`:音频强制切片,默认 0 为自动切片,单位为秒/s
++ `-s` | `--spk_list`:合成目标说话人名称,缺省时使用脚本中内置的默认说话人
++ `-n` | `--clean_names`:wav 文件名列表,放在 raw 文件夹下,缺省时回退为 raw 文件夹中的所有 wav 文件
 
 可选项部分:部分具体见下一节
++ `-cl` | `--clip`:音频强制切片,默认 0 为自动切片,单位为秒/s
 + `-lg` | `--linear_gradient`:两段音频切片的交叉淡入长度,如果强制切片后出现人声不连贯可调整该数值,如果连贯建议采用默认值 0,单位为秒
 + `-f0p` | `--f0_predictor`:选择 F0 预测器,可选择 crepe,pm,dio,harvest,rmvpe,fcpe, 默认为 pm(注意:crepe 为原 F0 使用均值滤波器)
 + `-a` | `--auto_predict_f0`:语音转换自动预测音高,转换歌声时不要打开这个会严重跑调
diff --git a/inference_main.py b/inference_main.py
index a99f6ec4..4c9c5649 100644
--- a/inference_main.py
+++ b/inference_main.py
@@ -1,5 +1,7 @@
 import logging
 
+import os
+
 import soundfile
 
 from inference import infer_tool
@@ -10,6 +12,17 @@
 chunks_dict = infer_tool.read_temp("inference/chunks_temp.json")
 
 
+def isWavFile(file_name):
+    return file_name.endswith('.wav')
+
+def getLastestCheckpoint(path):
+    files = [f for f in os.listdir(path) if f.endswith('.pth') and f.startswith('G_')]
+    if len(files) == 0:
+        logging.error(f"no checkpoint in {path}")
+        return None
+    latest_file = max(files, key=lambda f: int(f[2:-4]))  # compare by step number; a plain max() would rank "G_9000.pth" above "G_10000.pth"
+    return os.path.join(path, latest_file)
+
 def main():
     import argparse
 
@@ -17,14 +30,14 @@ def main():
     parser = argparse.ArgumentParser(description='sovits4 inference')
 
     # 一定要设置的部分
-    parser.add_argument('-m', '--model_path', type=str, default="logs/44k/G_37600.pth", help='模型路径')
+    parser.add_argument('-m', '--model_path', type=str, default="logs/44k/{lastest}", help='模型路径')
     parser.add_argument('-c', '--config_path', type=str, default="logs/44k/config.json", help='配置文件路径')
-    parser.add_argument('-cl', '--clip', type=float, default=0, help='音频强制切片,默认0为自动切片,单位为秒/s')
-    parser.add_argument('-n', '--clean_names', type=str, nargs='+', default=["君の知らない物語-src.wav"], help='wav文件名列表,放在raw文件夹下')
     parser.add_argument('-t', '--trans', type=int, nargs='+', default=[0], help='音高调整,支持正负(半音)')
     parser.add_argument('-s', '--spk_list', type=str, nargs='+', default=['buyizi'], help='合成目标说话人名称')
 
     # 可选项部分
+    parser.add_argument('-n', '--clean_names', type=str, nargs='+', default=[f for f in os.listdir("raw") if isWavFile(f)], help='wav文件名列表,放在raw文件夹下且后缀为.wav')
+    parser.add_argument('-cl', '--clip', type=float, default=0, help='音频强制切片,默认0为自动切片,单位为秒/s')
    parser.add_argument('-a', '--auto_predict_f0', action='store_true', default=False, help='语音转换自动预测音高,转换歌声时不要打开这个会严重跑调')
     parser.add_argument('-cm', '--cluster_model_path', type=str, default="", help='聚类模型或特征检索索引路径,留空则自动设为各方案模型的默认路径,如果没有训练聚类或特征检索则随便填')
     parser.add_argument('-cr', '--cluster_infer_ratio', type=float, default=0, help='聚类方案或特征检索占比,范围0-1,若没有训练聚类模型或特征检索则默认0即可')
@@ -82,6 +95,12 @@ def main():
     second_encoding = args.second_encoding
     loudness_envelope_adjustment = args.loudness_envelope_adjustment
 
+    model_path = args.model_path
+
+    if "{lastest}" in model_path:
+        model_path = getLastestCheckpoint(model_path[:model_path.find("{lastest}")])
+        logging.info("Auto choose %s", model_path)
+
     if cluster_infer_ratio != 0:
         if args.cluster_model_path == "":
             if args.feature_retrieval:  # 若指定了占比但没有指定模型路径,则按是否使用特征检索分配默认的模型路径
@@ -91,7 +110,7 @@
     else:  # 若未指定占比,则无论是否指定模型路径,都将其置空以避免之后的模型加载
         args.cluster_model_path = ""
-    svc_model = Svc(args.model_path,
+    svc_model = Svc(model_path,
                     args.config_path,
                     args.device,
                     args.cluster_model_path,
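The `{lastest}` placeholder in patch 2 hinges on picking the newest `G_<step>.pth` correctly. Two interchangeable sketches of that selection, one by parsed step number (what `getLastestCheckpoint` above does) and one by file modification time; both assume the repository's `G_<step>.pth` naming scheme:

```python
import os

def latest_by_step(path):
    # parse <step> out of "G_<step>.pth"; sorting the raw names would be
    # lexicographic and rank "G_9000.pth" above "G_10000.pth"
    files = [f for f in os.listdir(path) if f.startswith("G_") and f.endswith(".pth")]
    return os.path.join(path, max(files, key=lambda f: int(f[2:-4]))) if files else None

def latest_by_mtime(path):
    # alternative: trust filesystem timestamps instead of the naming scheme
    files = [os.path.join(path, f) for f in os.listdir(path)
             if f.startswith("G_") and f.endswith(".pth")]
    return max(files, key=os.path.getmtime) if files else None
```

Step-number parsing is the safer default here, since checkpoint files may be copied or restored with fresh timestamps.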
From b2aa17b36bae537c0f5bf9dd69d4b1c8d9172ee5 Mon Sep 17 00:00:00 2001
From: HuanLinOTO
Date: Sun, 22 Oct 2023 01:37:12 +0800
Subject: [PATCH 3/8] chore(logger): infer

---
 inference/infer_tool.py | 103 +++++++++++++++++++++-------------------
 inference_main.py       |  15 ++++--
 logger/__init__.py      |  37 +++++++++++++++
 test.py                 |  16 +++++++
 4 files changed, 119 insertions(+), 52 deletions(-)
 create mode 100644 logger/__init__.py
 create mode 100644 test.py

diff --git a/inference/infer_tool.py b/inference/infer_tool.py
index 06ebcaea..cbfb4c11 100644
--- a/inference/infer_tool.py
+++ b/inference/infer_tool.py
@@ -22,6 +22,8 @@
 from inference import slicer
 from models import SynthesizerTrn
 
+import logger
+
 logging.getLogger('matplotlib').setLevel(logging.WARNING)
 
 
@@ -67,6 +69,7 @@ def format_wav(audio_path):
     if Path(audio_path).suffix == '.wav':
         return
     raw_audio, raw_sample_rate = librosa.load(audio_path, mono=True, sr=None)
+
     soundfile.write(Path(audio_path).with_suffix(".wav"), raw_audio, raw_sample_rate)
 
 
@@ -443,56 +446,58 @@ def slice_inference(self,
         global_frame = 0
         audio = []
-        for (slice_tag, data) in audio_data:
-            print(f'#=====segment start, {round(len(data) / audio_sr, 3)}s======')
-            # padd
-            length = int(np.ceil(len(data) / audio_sr * self.target_sample))
-            if slice_tag:
-                print('jump empty segment')
-                _audio = np.zeros(length)
-                audio.extend(list(pad_array(_audio, length)))
-                global_frame += length // self.hop_size
-                continue
-            if per_size != 0:
-                datas = split_list_by_n(data, per_size,lg_size)
-            else:
-                datas = [data]
-            for k,dat in enumerate(datas):
-                per_length = int(np.ceil(len(dat) / audio_sr * self.target_sample)) if clip_seconds!=0 else length
-                if clip_seconds!=0:
-                    print(f'###=====segment clip start, {round(len(dat) / audio_sr, 3)}s======')
-                # padd
-                pad_len = int(audio_sr * pad_seconds)
-                dat = np.concatenate([np.zeros([pad_len]), dat, np.zeros([pad_len])])
-                raw_path = io.BytesIO()
-                soundfile.write(raw_path, dat, audio_sr, format="wav")
-                raw_path.seek(0)
-                out_audio, out_sr, out_frame = self.infer(spk, tran, raw_path,
-                                                    cluster_infer_ratio=cluster_infer_ratio,
-                                                    auto_predict_f0=auto_predict_f0,
-                                                    noice_scale=noice_scale,
-                                                    f0_predictor = f0_predictor,
-                                                    enhancer_adaptive_key = enhancer_adaptive_key,
-                                                    cr_threshold = cr_threshold,
-                                                    k_step = k_step,
-                                                    frame = global_frame,
-                                                    spk_mix = use_spk_mix,
-                                                    second_encoding = second_encoding,
-                                                    loudness_envelope_adjustment = loudness_envelope_adjustment
-                                                    )
-                global_frame += out_frame
-                _audio = out_audio.cpu().numpy()
-                pad_len = int(self.target_sample * pad_seconds)
-                _audio = _audio[pad_len:-pad_len]
-                _audio = pad_array(_audio, per_length)
-                if lg_size!=0 and k!=0:
-                    lg1 = audio[-(lg_size_r+lg_size_c_r):-lg_size_c_r] if lgr_num != 1 else audio[-lg_size:]
-                    lg2 = _audio[lg_size_c_l:lg_size_c_l+lg_size_r] if lgr_num != 1 else _audio[0:lg_size]
-                    lg_pre = lg1*(1-lg)+lg2*lg
-                    audio = audio[0:-(lg_size_r+lg_size_c_r)] if lgr_num != 1 else audio[0:-lg_size]
-                    audio.extend(lg_pre)
-                    _audio = _audio[lg_size_c_l+lg_size_r:] if lgr_num != 1 else _audio[lg_size:]
-                audio.extend(list(_audio))
+        with logger.Progress() as progress:
+            for (slice_tag, data) in progress.track(audio_data):
+                logger.info(f'#=====segment start, {round(len(data) / audio_sr, 3)}s======')
+                # padd
+                length = int(np.ceil(len(data) / audio_sr * self.target_sample))
+                if slice_tag:
+                    logger.info('jump empty segment')
+                    _audio = np.zeros(length)
+                    audio.extend(list(pad_array(_audio, length)))
+                    global_frame += length // self.hop_size
+                    continue
+                if per_size != 0:
+                    datas = split_list_by_n(data, per_size,lg_size)
+                else:
+                    datas = [data]
+                for k,dat in enumerate(datas):
+                    per_length = int(np.ceil(len(dat) / audio_sr * self.target_sample)) if clip_seconds!=0 else length
+                    if clip_seconds!=0:
+                        logger.info(f'###=====segment clip start, {round(len(dat) / audio_sr, 3)}s======')
+                    # padd
+                    pad_len = int(audio_sr * pad_seconds)
+                    dat = np.concatenate([np.zeros([pad_len]), dat, np.zeros([pad_len])])
+                    raw_path = io.BytesIO()
+                    soundfile.write(raw_path, dat, audio_sr, format="wav")
+                    raw_path.seek(0)
+                    out_audio, out_sr, out_frame = self.infer(spk, tran, raw_path,
+                                                        cluster_infer_ratio=cluster_infer_ratio,
+                                                        auto_predict_f0=auto_predict_f0,
+                                                        noice_scale=noice_scale,
+                                                        f0_predictor = f0_predictor,
+                                                        enhancer_adaptive_key = enhancer_adaptive_key,
+                                                        cr_threshold = cr_threshold,
+                                                        k_step = k_step,
+                                                        frame = global_frame,
+                                                        spk_mix = use_spk_mix,
+                                                        second_encoding = second_encoding,
+                                                        loudness_envelope_adjustment = loudness_envelope_adjustment
+                                                        )
+                    global_frame += out_frame
+                    _audio = out_audio.cpu().numpy()
+                    pad_len = int(self.target_sample * pad_seconds)
+                    _audio = _audio[pad_len:-pad_len]
+                    _audio = pad_array(_audio, per_length)
+                    if lg_size!=0 and k!=0:
+                        lg1 = audio[-(lg_size_r+lg_size_c_r):-lg_size_c_r] if lgr_num != 1 else audio[-lg_size:]
+                        lg2 = _audio[lg_size_c_l:lg_size_c_l+lg_size_r] if lgr_num != 1 else _audio[0:lg_size]
+                        lg_pre = lg1*(1-lg)+lg2*lg
+                        audio = audio[0:-(lg_size_r+lg_size_c_r)] if lgr_num != 1 else audio[0:-lg_size]
+                        audio.extend(lg_pre)
+                        _audio = _audio[lg_size_c_l+lg_size_r:] if lgr_num != 1 else _audio[lg_size:]
+                    audio.extend(list(_audio))
+
         return np.array(audio)
 
 class RealTimeVC:
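`logger.Progress()` (the wrapper this patch adds in `logger/__init__.py` below) returns a `rich.progress.Progress`, whose `track()` method drives a progress bar from any iterable. A self-contained sketch of the same wiring, independent of this repository's `logger` package, with placeholder segment data:

```python
import time

from rich.progress import Progress

# placeholder stand-ins for the (slice_tag, data) pairs slice_inference consumes
segments = [("speech", 1.2), ("silence", 0.4), ("speech", 2.1)]

with Progress() as progress:
    # total is inferred via len(segments); for generators, pass total= explicitly
    for tag, seconds in progress.track(segments, description="inference"):
        time.sleep(0.01)  # stand-in for the per-segment synthesis work
```

Wrapping the loop in `with Progress()` matters: the context manager starts and stops the live display, so log lines printed through the same `Console` are rendered above the bar instead of corrupting it.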
diff --git a/inference_main.py b/inference_main.py
index 4c9c5649..4c8c7f7c 100644
--- a/inference_main.py
+++ b/inference_main.py
@@ -1,5 +1,9 @@
 import logging
 
+# from loguru import logger
+import logger
+
+
 import os
 
 import soundfile
@@ -18,7 +22,7 @@ def isWavFile(file_name):
 def getLastestCheckpoint(path):
     files = [f for f in os.listdir(path) if f.endswith('.pth') and f.startswith('G_')]
     if len(files) == 0:
-        logging.error(f"no checkpoint in {path}")
+        logger.error(f"no checkpoint in {path}")
         return None
     latest_file = max(files, key=lambda f: int(f[2:-4]))  # compare by step number; a plain max() would rank "G_9000.pth" above "G_10000.pth"
     return os.path.join(path, latest_file)
@@ -99,7 +103,7 @@ def main():
 
     if "{lastest}" in model_path:
         model_path = getLastestCheckpoint(model_path[:model_path.find("{lastest}")])
-        logging.info("Auto choose %s", model_path)
+        logger.info("Auto choose {}", model_path)
 
     if cluster_infer_ratio != 0:
@@ -156,7 +160,12 @@ def main():
                 "second_encoding":second_encoding,
                 "loudness_envelope_adjustment":loudness_envelope_adjustment
             }
-            audio = svc_model.slice_inference(**kwarg)
+            try:
+                audio = svc_model.slice_inference(**kwarg)
+            except Exception as e:
+                logger.error(f"Error in {clean_name} {spk} {tran}")
+                logger.error(e)
+                return
             key = "auto" if auto_predict_f0 else f"{tran}key"
             cluster_name = "" if cluster_infer_ratio == 0 else f"_{cluster_infer_ratio}"
             isdiffusion = "sovits"
diff --git a/logger/__init__.py b/logger/__init__.py
new file mode 100644
index 00000000..05cdc0ce
--- /dev/null
+++ b/logger/__init__.py
@@ -0,0 +1,37 @@
+from loguru import logger
+
+from rich.console import 
Console +from rich.logging import RichHandler +from rich.progress import Progress as _Progress + +console = Console(stderr=None) +# logger.remove() +# handler = RichHandler(console=console) +# # logger.add( +# # # lambda _: console.print(_, end=''), +# # handler, +# # level='TRACE', +# # # format=log_formater, +# # format= "[green]{time:YYYY-MM-DD HH:mm:ss.SSS}[/green] | " +# # "[level]{level: <8}[/level] | " +# # "[cyan]{name}[/cyan]:[cyan]{function}[/cyan]:[cyan]{line}[/cyan] - [level]{message}[/level]", +# # colorize=True, +# # ) +# logger.add(handler, colorize=True) + +logger.remove() # Remove default 'stderr' handler + +# We need to specify end=''" as log message already ends with \n (thus the lambda function) +# Also forcing 'colorize=True' otherwise Loguru won't recognize that the sink support colors +logger.add(lambda m: console.print(m, end=""), + format="[green]{time:YYYY-MM-DD HH:mm:ss.SSS}[/green] | " + "[level]{level: <8}[/level] | " + "[cyan]{name}[/cyan]:[cyan]{function}[/cyan]:[cyan]{line}[/cyan] - [level]{message}[/level]", + colorize=True) +info = logger.info +error = logger.error +warning = logger.warning +debug = logger.debug + +def Progress(): + return _Progress(console=console) \ No newline at end of file diff --git a/test.py b/test.py new file mode 100644 index 00000000..4694c8c4 --- /dev/null +++ b/test.py @@ -0,0 +1,16 @@ +import time +import logger + +def work(i): + logger.info("[red]Working[/red]: {}", i) + time.sleep(0.05) + +def main(): + logger.info("Before") + with logger.Progress() as progress: + for i in progress.track(range(100)): + work(i) + logger.info("After") + +if __name__ == "__main__": + main() \ No newline at end of file From 644867553a8e7f28cc6b5272dbdbd0a1e039323f Mon Sep 17 00:00:00 2001 From: HuanLinOTO Date: Sun, 22 Oct 2023 13:37:44 +0800 Subject: [PATCH 4/8] chore(logger): more --- configs/config.json | 107 ++++++++++++++++++++++++++++++ configs/diffusion.yaml | 51 ++++++++++++++ filelists/train.txt | 16 +---- filelists/val.txt | 6 +- inference/infer_tool.py | 4 +- logger/__init__.py | 40 +++++------ preprocess_flist_config.py | 16 +++-- test.py | 2 +- utils.py | 30 +++++---- vencoder/CNHubertLarge.py | 3 +- vencoder/ContentVec256L12_Onnx.py | 3 +- vencoder/ContentVec256L9.py | 3 +- vencoder/ContentVec256L9_Onnx.py | 3 +- vencoder/ContentVec768L12.py | 3 +- vencoder/ContentVec768L12_Onnx.py | 3 +- vencoder/ContentVec768L9_Onnx.py | 3 +- vencoder/DPHubert.py | 3 +- vencoder/HubertSoft.py | 3 +- vencoder/HubertSoft_Onnx.py | 3 +- vencoder/WavLMBasePlus.py | 3 +- 20 files changed, 232 insertions(+), 73 deletions(-) diff --git a/configs/config.json b/configs/config.json index e69de29b..09a7b372 100644 --- a/configs/config.json +++ b/configs/config.json @@ -0,0 +1,107 @@ +{ + "train": { + "log_interval": 200, + "eval_interval": 800, + "seed": 1234, + "epochs": 10000, + "learning_rate": 0.0001, + "betas": [ + 0.8, + 0.99 + ], + "eps": 1e-09, + "batch_size": 6, + "fp16_run": false, + "half_type": "fp16", + "lr_decay": 0.999875, + "segment_size": 10240, + "init_lr_ratio": 1, + "warmup_epochs": 0, + "c_mel": 45, + "c_kl": 1.0, + "use_sr": true, + "max_speclen": 512, + "port": "8001", + "keep_ckpts": 3, + "all_in_mem": false, + "vol_aug": false + }, + "data": { + "training_files": "filelists/train.txt", + "validation_files": "filelists/val.txt", + "max_wav_value": 32768.0, + "sampling_rate": 44100, + "filter_length": 2048, + "hop_length": 512, + "win_length": 2048, + "n_mel_channels": 80, + "mel_fmin": 0.0, + "mel_fmax": 22050, + 
"unit_interpolate_mode": "nearest" + }, + "model": { + "inter_channels": 192, + "hidden_channels": 192, + "filter_channels": 768, + "n_heads": 2, + "n_layers": 6, + "kernel_size": 3, + "p_dropout": 0.1, + "resblock": "1", + "resblock_kernel_sizes": [ + 3, + 7, + 11 + ], + "resblock_dilation_sizes": [ + [ + 1, + 3, + 5 + ], + [ + 1, + 3, + 5 + ], + [ + 1, + 3, + 5 + ] + ], + "upsample_rates": [ + 8, + 8, + 2, + 2, + 2 + ], + "upsample_initial_channel": 512, + "upsample_kernel_sizes": [ + 16, + 16, + 4, + 4, + 4 + ], + "n_layers_q": 3, + "n_layers_trans_flow": 3, + "n_flow_layer": 4, + "use_spectral_norm": false, + "gin_channels": 768, + "ssl_dim": 768, + "n_speakers": 1, + "vocoder_name": "nsf-hifigan", + "speech_encoder": "vec768l12", + "speaker_embedding": false, + "vol_embedding": false, + "use_depthwise_conv": false, + "flow_share_parameter": false, + "use_automatic_f0_prediction": true, + "use_transformer_flow": false + }, + "spk": { + "TEST": 0 + } +} \ No newline at end of file diff --git a/configs/diffusion.yaml b/configs/diffusion.yaml index e69de29b..51c0046a 100644 --- a/configs/diffusion.yaml +++ b/configs/diffusion.yaml @@ -0,0 +1,51 @@ +data: + block_size: 512 + cnhubertsoft_gate: 10 + duration: 2 + encoder: vec768l12 + encoder_hop_size: 320 + encoder_out_channels: 768 + encoder_sample_rate: 16000 + extensions: + - wav + sampling_rate: 44100 + training_files: filelists/train.txt + unit_interpolate_mode: nearest + validation_files: filelists/val.txt +device: cuda +env: + expdir: logs/44k/diffusion + gpu_id: 0 +infer: + method: dpm-solver++ + speedup: 10 +model: + k_step_max: 0 + n_chans: 512 + n_hidden: 256 + n_layers: 20 + n_spk: 1 + timesteps: 1000 + type: Diffusion + use_pitch_aug: true +spk: + TEST: 0 +train: + amp_dtype: fp32 + batch_size: 48 + cache_all_data: true + cache_device: cpu + cache_fp16: true + decay_step: 100000 + epochs: 100000 + gamma: 0.5 + interval_force_save: 5000 + interval_log: 10 + interval_val: 2000 + lr: 0.0001 + num_workers: 4 + save_opt: false + weight_decay: 0 +vocoder: + ckpt: pretrain/nsf_hifigan/model + type: nsf-hifigan diff --git a/filelists/train.txt b/filelists/train.txt index acdb3cce..3a391b43 100644 --- a/filelists/train.txt +++ b/filelists/train.txt @@ -1,15 +1 @@ -./dataset/44k/taffy/000549.wav -./dataset/44k/nyaru/000004.wav -./dataset/44k/nyaru/000006.wav -./dataset/44k/taffy/000551.wav -./dataset/44k/nyaru/000009.wav -./dataset/44k/taffy/000561.wav -./dataset/44k/nyaru/000001.wav -./dataset/44k/taffy/000553.wav -./dataset/44k/nyaru/000002.wav -./dataset/44k/taffy/000560.wav -./dataset/44k/taffy/000557.wav -./dataset/44k/nyaru/000005.wav -./dataset/44k/taffy/000554.wav -./dataset/44k/taffy/000550.wav -./dataset/44k/taffy/000559.wav +./dataset/44k/TEST/vo_tips_TCG001_4_littlePrince_02.wav diff --git a/filelists/val.txt b/filelists/val.txt index 262dfc97..b73d2fa5 100644 --- a/filelists/val.txt +++ b/filelists/val.txt @@ -1,4 +1,2 @@ -./dataset/44k/nyaru/000003.wav -./dataset/44k/nyaru/000007.wav -./dataset/44k/taffy/000558.wav -./dataset/44k/taffy/000556.wav +./dataset/44k/TEST/vo_tips_TCG001_37_littlePrince_01.wav +./dataset/44k/TEST/vo_tips_TCG001_4_littlePrince_01.wav diff --git a/inference/infer_tool.py b/inference/infer_tool.py index cbfb4c11..19956660 100644 --- a/inference/infer_tool.py +++ b/inference/infer_tool.py @@ -448,7 +448,7 @@ def slice_inference(self, audio = [] with logger.Progress() as progress: for (slice_tag, data) in progress.track(audio_data): - logger.info(f'#=====segment start, {round(len(data) / audio_sr, 
3)}s======') + logger.info(f'segment start, {round(len(data) / audio_sr, 3)}s') # padd length = int(np.ceil(len(data) / audio_sr * self.target_sample)) if slice_tag: @@ -464,7 +464,7 @@ def slice_inference(self, for k,dat in enumerate(datas): per_length = int(np.ceil(len(dat) / audio_sr * self.target_sample)) if clip_seconds!=0 else length if clip_seconds!=0: - logger.info(f'###=====segment clip start, {round(len(dat) / audio_sr, 3)}s======') + logger.info(f'segment clip start, {round(len(dat) / audio_sr, 3)}s') # padd pad_len = int(audio_sr * pad_seconds) dat = np.concatenate([np.zeros([pad_len]), dat, np.zeros([pad_len])]) diff --git a/logger/__init__.py b/logger/__init__.py index 05cdc0ce..75db53fe 100644 --- a/logger/__init__.py +++ b/logger/__init__.py @@ -3,35 +3,37 @@ from rich.console import Console from rich.logging import RichHandler from rich.progress import Progress as _Progress +from rich.progress import Progress, TextColumn, BarColumn, TimeElapsedColumn, TimeRemainingColumn + +import time console = Console(stderr=None) -# logger.remove() -# handler = RichHandler(console=console) -# # logger.add( -# # # lambda _: console.print(_, end=''), -# # handler, -# # level='TRACE', -# # # format=log_formater, -# # format= "[green]{time:YYYY-MM-DD HH:mm:ss.SSS}[/green] | " -# # "[level]{level: <8}[/level] | " -# # "[cyan]{name}[/cyan]:[cyan]{function}[/cyan]:[cyan]{line}[/cyan] - [level]{message}[/level]", -# # colorize=True, -# # ) -# logger.add(handler, colorize=True) - -logger.remove() # Remove default 'stderr' handler - -# We need to specify end=''" as log message already ends with \n (thus the lambda function) -# Also forcing 'colorize=True' otherwise Loguru won't recognize that the sink support colors + +logger.remove() + logger.add(lambda m: console.print(m, end=""), format="[green]{time:YYYY-MM-DD HH:mm:ss.SSS}[/green] | " "[level]{level: <8}[/level] | " "[cyan]{name}[/cyan]:[cyan]{function}[/cyan]:[cyan]{line}[/cyan] - [level]{message}[/level]", colorize=True) + +def addLogger(path): + logger.add(path, format="{time:YYYY-MM-DD HH:mm:ss.SSS} | " + "{level: <8} | " + "{name}:{function}:{line} - {message}", + colorize=True) + info = logger.info error = logger.error warning = logger.warning debug = logger.debug def Progress(): - return _Progress(console=console) \ No newline at end of file + return _Progress( + TextColumn("[progress.description]{task.description}"), + # TextColumn("[progress.description]W"), + BarColumn(), + TextColumn("[progress.percentage]{task.percentage:>3.0f}%"), + TimeRemainingColumn(), + TextColumn("[red]*Elapsed[/red]"), + TimeElapsedColumn(),console=console) \ No newline at end of file diff --git a/preprocess_flist_config.py b/preprocess_flist_config.py index 1335d7d7..ee19679b 100644 --- a/preprocess_flist_config.py +++ b/preprocess_flist_config.py @@ -5,7 +5,7 @@ import wave from random import shuffle -from loguru import logger +import logger from tqdm import tqdm import diffusion.logger.utils as du @@ -73,15 +73,17 @@ def get_wav_duration(file_path): logger.info("Writing " + args.train_list) with open(args.train_list, "w") as f: - for fname in tqdm(train): - wavpath = fname - f.write(wavpath + "\n") + with logger.Progress() as progress: + for fname in progress.track(train): + wavpath = fname + f.write(wavpath + "\n") logger.info("Writing " + args.val_list) with open(args.val_list, "w") as f: - for fname in tqdm(val): - wavpath = fname - f.write(wavpath + "\n") + with logger.Progress() as progress: + for fname in progress.track(val): + wavpath = fname + 
f.write(wavpath + "\n") d_config_template = du.load_config("configs_template/diffusion_template.yaml") diff --git a/test.py b/test.py index 4694c8c4..a4888390 100644 --- a/test.py +++ b/test.py @@ -2,7 +2,7 @@ import logger def work(i): - logger.info("[red]Working[/red]: {}", i) + # logger.info("[red]Working[/red]: {}", i) time.sleep(0.05) def main(): diff --git a/utils.py b/utils.py index 95b6d888..ca1f349c 100644 --- a/utils.py +++ b/utils.py @@ -17,10 +17,12 @@ from sklearn.cluster import MiniBatchKMeans from torch.nn import functional as F +import logger + MATPLOTLIB_FLAG = False logging.basicConfig(stream=sys.stdout, level=logging.WARN) -logger = logging +# logger = logging f0_bin = 256 f0_max = 1100.0 @@ -174,14 +176,13 @@ def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False assert saved_state_dict[k].shape == v.shape, (saved_state_dict[k].shape, v.shape) except Exception: if "enc_q" not in k or "emb_g" not in k: - print("%s is not in the checkpoint,please check your checkpoint.If you're using pretrain model,just ignore this warning." % k) + logger.error("%s is not in the checkpoint,please check your checkpoint.If you're using pretrain model,just ignore this warning." % k) logger.info("%s is not in the checkpoint" % k) new_state_dict[k] = v if hasattr(model, 'module'): model.module.load_state_dict(new_state_dict) else: model.load_state_dict(new_state_dict) - print("load ") logger.info("Loaded checkpoint '{}' (iteration {})".format( checkpoint_path, iteration)) return model, optimizer, learning_rate, iteration @@ -379,17 +380,18 @@ def check_git_hash(model_dir): def get_logger(model_dir, filename="train.log"): - global logger - logger = logging.getLogger(os.path.basename(model_dir)) - logger.setLevel(logging.DEBUG) - - formatter = logging.Formatter("%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s") - if not os.path.exists(model_dir): - os.makedirs(model_dir) - h = logging.FileHandler(os.path.join(model_dir, filename)) - h.setLevel(logging.DEBUG) - h.setFormatter(formatter) - logger.addHandler(h) + import logger + logger.addLogger(os.path.join(model_dir, filename)) + # logger = logging.getLogger(os.path.basename(model_dir)) + # logger.setLevel(logging.DEBUG) + + # formatter = logging.Formatter("%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s") + # if not os.path.exists(model_dir): + # os.makedirs(model_dir) + # h = logging.FileHandler(os.path.join(model_dir, filename)) + # h.setLevel(logging.DEBUG) + # h.setFormatter(formatter) + # logger.addHandler(h) return logger diff --git a/vencoder/CNHubertLarge.py b/vencoder/CNHubertLarge.py index f4369476..fdaa0c20 100644 --- a/vencoder/CNHubertLarge.py +++ b/vencoder/CNHubertLarge.py @@ -7,7 +7,8 @@ class CNHubertLarge(SpeechEncoder): def __init__(self, vec_path="pretrain/chinese-hubert-large-fairseq-ckpt.pt", device=None): super().__init__() - print("load model(s) from {}".format(vec_path)) + import logger + logger.info("load model(s) from {}".format(vec_path)) self.hidden_dim = 1024 models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task( [vec_path], diff --git a/vencoder/ContentVec256L12_Onnx.py b/vencoder/ContentVec256L12_Onnx.py index 466e6c12..601da1ea 100644 --- a/vencoder/ContentVec256L12_Onnx.py +++ b/vencoder/ContentVec256L12_Onnx.py @@ -7,7 +7,8 @@ class ContentVec256L12_Onnx(SpeechEncoder): def __init__(self, vec_path="pretrain/vec-256-layer-12.onnx", device=None): super().__init__() - print("load model(s) from {}".format(vec_path)) + import logger + logger.info("load model(s) from 
{}".format(vec_path)) self.hidden_dim = 256 if device is None: self.dev = torch.device("cpu") diff --git a/vencoder/ContentVec256L9.py b/vencoder/ContentVec256L9.py index c973090d..dfcab8b1 100644 --- a/vencoder/ContentVec256L9.py +++ b/vencoder/ContentVec256L9.py @@ -7,7 +7,8 @@ class ContentVec256L9(SpeechEncoder): def __init__(self, vec_path="pretrain/checkpoint_best_legacy_500.pt", device=None): super().__init__() - print("load model(s) from {}".format(vec_path)) + import logger + logger.info("load model(s) from {}".format(vec_path)) models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task( [vec_path], suffix="", diff --git a/vencoder/ContentVec256L9_Onnx.py b/vencoder/ContentVec256L9_Onnx.py index a27e1f76..78173da4 100644 --- a/vencoder/ContentVec256L9_Onnx.py +++ b/vencoder/ContentVec256L9_Onnx.py @@ -7,7 +7,8 @@ class ContentVec256L9_Onnx(SpeechEncoder): def __init__(self, vec_path="pretrain/vec-256-layer-9.onnx", device=None): super().__init__() - print("load model(s) from {}".format(vec_path)) + import logger + logger.info("load model(s) from {}".format(vec_path)) self.hidden_dim = 256 if device is None: self.dev = torch.device("cpu") diff --git a/vencoder/ContentVec768L12.py b/vencoder/ContentVec768L12.py index 066b824b..ad06a979 100644 --- a/vencoder/ContentVec768L12.py +++ b/vencoder/ContentVec768L12.py @@ -7,7 +7,8 @@ class ContentVec768L12(SpeechEncoder): def __init__(self, vec_path="pretrain/checkpoint_best_legacy_500.pt", device=None): super().__init__() - print("load model(s) from {}".format(vec_path)) + import logger + logger.info("load model(s) from {}".format(vec_path)) self.hidden_dim = 768 models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task( [vec_path], diff --git a/vencoder/ContentVec768L12_Onnx.py b/vencoder/ContentVec768L12_Onnx.py index e7375945..cf4e2705 100644 --- a/vencoder/ContentVec768L12_Onnx.py +++ b/vencoder/ContentVec768L12_Onnx.py @@ -7,7 +7,8 @@ class ContentVec768L12_Onnx(SpeechEncoder): def __init__(self, vec_path="pretrain/vec-768-layer-12.onnx", device=None): super().__init__() - print("load model(s) from {}".format(vec_path)) + import logger + logger.info("load model(s) from {}".format(vec_path)) self.hidden_dim = 768 if device is None: self.dev = torch.device("cpu") diff --git a/vencoder/ContentVec768L9_Onnx.py b/vencoder/ContentVec768L9_Onnx.py index 3bd0f337..a203eb52 100644 --- a/vencoder/ContentVec768L9_Onnx.py +++ b/vencoder/ContentVec768L9_Onnx.py @@ -7,7 +7,8 @@ class ContentVec768L9_Onnx(SpeechEncoder): def __init__(self,vec_path = "pretrain/vec-768-layer-9.onnx",device=None): super().__init__() - print("load model(s) from {}".format(vec_path)) + import logger + logger.info("load model(s) from {}".format(vec_path)) self.hidden_dim = 768 if device is None: self.dev = torch.device("cpu") diff --git a/vencoder/DPHubert.py b/vencoder/DPHubert.py index 130064ff..540eb58c 100644 --- a/vencoder/DPHubert.py +++ b/vencoder/DPHubert.py @@ -7,7 +7,8 @@ class DPHubert(SpeechEncoder): def __init__(self, vec_path="pretrain/DPHuBERT-sp0.75.pth", device=None): super().__init__() - print("load model(s) from {}".format(vec_path)) + import logger + logger.info("load model(s) from {}".format(vec_path)) if device is None: self.dev = torch.device("cuda" if torch.cuda.is_available() else "cpu") else: diff --git a/vencoder/HubertSoft.py b/vencoder/HubertSoft.py index 423c159c..b3b01442 100644 --- a/vencoder/HubertSoft.py +++ b/vencoder/HubertSoft.py @@ -7,7 +7,8 @@ class HubertSoft(SpeechEncoder): def __init__(self, 
vec_path="pretrain/hubert-soft-0d54a1f4.pt", device=None): super().__init__() - print("load model(s) from {}".format(vec_path)) + import logger + logger.info("load model(s) from {}".format(vec_path)) hubert_soft = hubert_model.hubert_soft(vec_path) if device is None: self.dev = torch.device("cuda" if torch.cuda.is_available() else "cpu") diff --git a/vencoder/HubertSoft_Onnx.py b/vencoder/HubertSoft_Onnx.py index 038d78e8..52fadf95 100644 --- a/vencoder/HubertSoft_Onnx.py +++ b/vencoder/HubertSoft_Onnx.py @@ -7,7 +7,8 @@ class HubertSoft_Onnx(SpeechEncoder): def __init__(self, vec_path="pretrain/hubert-soft.onnx", device=None): super().__init__() - print("load model(s) from {}".format(vec_path)) + import logger + logger.info("load model(s) from {}".format(vec_path)) self.hidden_dim = 256 if device is None: self.dev = torch.device("cpu") diff --git a/vencoder/WavLMBasePlus.py b/vencoder/WavLMBasePlus.py index 99df15be..8dc82592 100644 --- a/vencoder/WavLMBasePlus.py +++ b/vencoder/WavLMBasePlus.py @@ -7,7 +7,8 @@ class WavLMBasePlus(SpeechEncoder): def __init__(self, vec_path="pretrain/WavLM-Base+.pt", device=None): super().__init__() - print("load model(s) from {}".format(vec_path)) + import logger + logger.info("load model(s) from {}".format(vec_path)) checkpoint = torch.load(vec_path) self.cfg = WavLMConfig(checkpoint['cfg']) if device is None: From 24a840c30f4ff351c4de2b1a36acba98d49e6675 Mon Sep 17 00:00:00 2001 From: HuanLinOTO Date: Sat, 28 Oct 2023 15:30:07 +0800 Subject: [PATCH 5/8] chore(logger): train --- filelists/train.txt | 15 +++ filelists/val.txt | 4 +- logger/__init__.py | 69 ++++++++--- preprocess_flist_config.py | 120 +++++++++---------- test.py | 9 +- train.py | 236 +++++++++++++++++++------------------ 6 files changed, 255 insertions(+), 198 deletions(-) diff --git a/filelists/train.txt b/filelists/train.txt index 3a391b43..9c94dbec 100644 --- a/filelists/train.txt +++ b/filelists/train.txt @@ -1 +1,16 @@ +./dataset/44k/TEST/354440190 (9).wav +./dataset/44k/TEST/354440190 (6).wav +./dataset/44k/TEST/354440190 (1).wav +./dataset/44k/TEST/354440190 (11).wav +./dataset/44k/TEST/354440190 (10).wav +./dataset/44k/TEST/354440190 (15).wav +./dataset/44k/TEST/354440190 (12).wav +./dataset/44k/TEST/354440190 (2).wav +./dataset/44k/TEST/354440190 (7).wav +./dataset/44k/TEST/vo_tips_TCG001_4_littlePrince_01.wav +./dataset/44k/TEST/vo_tips_TCG001_37_littlePrince_01.wav +./dataset/44k/TEST/354440190 (5).wav ./dataset/44k/TEST/vo_tips_TCG001_4_littlePrince_02.wav +./dataset/44k/TEST/354440190 (4).wav +./dataset/44k/TEST/354440190 (14).wav +./dataset/44k/TEST/354440190 (13).wav diff --git a/filelists/val.txt b/filelists/val.txt index b73d2fa5..feb10a1f 100644 --- a/filelists/val.txt +++ b/filelists/val.txt @@ -1,2 +1,2 @@ -./dataset/44k/TEST/vo_tips_TCG001_37_littlePrince_01.wav -./dataset/44k/TEST/vo_tips_TCG001_4_littlePrince_01.wav +./dataset/44k/TEST/354440190 (3).wav +./dataset/44k/TEST/354440190 (8).wav diff --git a/logger/__init__.py b/logger/__init__.py index 75db53fe..fed6afa9 100644 --- a/logger/__init__.py +++ b/logger/__init__.py @@ -3,37 +3,72 @@ from rich.console import Console from rich.logging import RichHandler from rich.progress import Progress as _Progress -from rich.progress import Progress, TextColumn, BarColumn, TimeElapsedColumn, TimeRemainingColumn +from rich.progress import ( + Progress, + TextColumn, + BarColumn, + TimeElapsedColumn, + TimeRemainingColumn, +) import time +import os +import datetime + console = Console(stderr=None) logger.remove() 
-logger.add(lambda m: console.print(m, end=""), - format="[green]{time:YYYY-MM-DD HH:mm:ss.SSS}[/green] | " - "[level]{level: <8}[/level] | " - "[cyan]{name}[/cyan]:[cyan]{function}[/cyan]:[cyan]{line}[/cyan] - [level]{message}[/level]", - colorize=True) + +def format_level(str, length): + if len(str) < length: + str = str + " " * (length - len(str)) + else: + str = str + # 给 str 上对应 level 的颜色 + if str == "INFO ": + str = f"[bold green]{str}[/bold green]" + elif str == "WARNING": + str = f"[bold yellow]{str}[/bold yellow]" + elif str == "ERROR ": + str = f"[bold red]{str}[/bold red]" + elif str == "DEBUG ": + str = f"[bold cyan]{str}[/bold cyan]" + return str + +def default_format(record): + # print(record) + return f"[green]{record['time'].strftime('%Y-%m-%d %H:%M:%S')}[/green] | [level]{format_level(record['level'].name,7)}[/level] | [cyan]{record['file'].path.replace(os.getcwd()+os.sep,'')}[/cyan] - [level]{record['message']}[/level]\n" + + +logger.add(lambda m: console.print(m, end=""), format=default_format, colorize=True) + def addLogger(path): - logger.add(path, format="{time:YYYY-MM-DD HH:mm:ss.SSS} | " - "{level: <8} | " - "{name}:{function}:{line} - {message}", - colorize=True) + logger.add( + path, + format="{time:YYYY-MM-DD HH:mm:ss.SSS} | " + "{level: <8} | " + "{name}:{function}:{line} - {message}", + colorize=True, + ) + info = logger.info error = logger.error warning = logger.warning debug = logger.debug + def Progress(): return _Progress( - TextColumn("[progress.description]{task.description}"), - # TextColumn("[progress.description]W"), - BarColumn(), - TextColumn("[progress.percentage]{task.percentage:>3.0f}%"), - TimeRemainingColumn(), - TextColumn("[red]*Elapsed[/red]"), - TimeElapsedColumn(),console=console) \ No newline at end of file + TextColumn("[progress.description]{task.description}"), + # TextColumn("[progress.description]W"), + BarColumn(), + TextColumn("[progress.percentage]{task.percentage:>3.0f}%"), + TimeRemainingColumn(), + TextColumn("[red]*Elapsed[/red]"), + TimeElapsedColumn(), + console=console, + ) diff --git a/preprocess_flist_config.py b/preprocess_flist_config.py index ee19679b..b1408701 100644 --- a/preprocess_flist_config.py +++ b/preprocess_flist_config.py @@ -10,6 +10,8 @@ import diffusion.logger.utils as du +import time + pattern = re.compile(r'^[\.a-zA-Z0-9_\/]+$') def get_wav_duration(file_path): @@ -41,81 +43,79 @@ def get_wav_duration(file_path): idx = 0 spk_dict = {} spk_id = 0 + with logger.Progress() as progress: + for speaker in progress.track(os.listdir(args.source_dir), description="Processing Speakers"): + spk_dict[speaker] = spk_id + spk_id += 1 + wavs = [] - for speaker in tqdm(os.listdir(args.source_dir)): - spk_dict[speaker] = spk_id - spk_id += 1 - wavs = [] - - for file_name in os.listdir(os.path.join(args.source_dir, speaker)): - if not file_name.endswith("wav"): - continue - if file_name.startswith("."): - continue + for file_name in os.listdir(os.path.join(args.source_dir, speaker)): + if not file_name.endswith("wav"): + continue + if file_name.startswith("."): + continue - file_path = "/".join([args.source_dir, speaker, file_name]) + file_path = "/".join([args.source_dir, speaker, file_name]) - if not pattern.match(file_name): - logger.warning("Detected non-ASCII file name: " + file_path) + if not pattern.match(file_name): + logger.warning("Detected non-ASCII file name: " + file_path) - if get_wav_duration(file_path) < 0.3: - logger.info("Skip too short audio: " + file_path) - continue + if get_wav_duration(file_path) < 
0.3: + logger.info("Skip too short audio: " + file_path) + continue - wavs.append(file_path) + wavs.append(file_path) - shuffle(wavs) - train += wavs[2:] - val += wavs[:2] + shuffle(wavs) + train += wavs[2:] + val += wavs[:2] - shuffle(train) - shuffle(val) + shuffle(train) + shuffle(val) - logger.info("Writing " + args.train_list) - with open(args.train_list, "w") as f: - with logger.Progress() as progress: - for fname in progress.track(train): + logger.info("Writing " + args.train_list) + with open(args.train_list, "w") as f: + for fname in progress.track(train, description="Writing train list"): wavpath = fname f.write(wavpath + "\n") - logger.info("Writing " + args.val_list) - with open(args.val_list, "w") as f: - with logger.Progress() as progress: - for fname in progress.track(val): + logger.info("Writing " + args.val_list) + with open(args.val_list, "w") as f: + for fname in progress.track(val, description="Writing val list"): wavpath = fname f.write(wavpath + "\n") - d_config_template = du.load_config("configs_template/diffusion_template.yaml") - d_config_template["model"]["n_spk"] = spk_id - d_config_template["data"]["encoder"] = args.speech_encoder - d_config_template["spk"] = spk_dict - - config_template["spk"] = spk_dict - config_template["model"]["n_speakers"] = spk_id - config_template["model"]["speech_encoder"] = args.speech_encoder - - if args.speech_encoder == "vec768l12" or args.speech_encoder == "dphubert" or args.speech_encoder == "wavlmbase+": - config_template["model"]["ssl_dim"] = config_template["model"]["filter_channels"] = config_template["model"]["gin_channels"] = 768 - d_config_template["data"]["encoder_out_channels"] = 768 - elif args.speech_encoder == "vec256l9" or args.speech_encoder == 'hubertsoft': - config_template["model"]["ssl_dim"] = config_template["model"]["gin_channels"] = 256 - d_config_template["data"]["encoder_out_channels"] = 256 - elif args.speech_encoder == "whisper-ppg" or args.speech_encoder == 'cnhubertlarge': - config_template["model"]["ssl_dim"] = config_template["model"]["filter_channels"] = config_template["model"]["gin_channels"] = 1024 - d_config_template["data"]["encoder_out_channels"] = 1024 - elif args.speech_encoder == "whisper-ppg-large": - config_template["model"]["ssl_dim"] = config_template["model"]["filter_channels"] = config_template["model"]["gin_channels"] = 1280 - d_config_template["data"]["encoder_out_channels"] = 1280 + d_config_template = du.load_config("configs_template/diffusion_template.yaml") + d_config_template["model"]["n_spk"] = spk_id + d_config_template["data"]["encoder"] = args.speech_encoder + d_config_template["spk"] = spk_dict + + config_template["spk"] = spk_dict + config_template["model"]["n_speakers"] = spk_id + config_template["model"]["speech_encoder"] = args.speech_encoder + + if args.speech_encoder == "vec768l12" or args.speech_encoder == "dphubert" or args.speech_encoder == "wavlmbase+": + config_template["model"]["ssl_dim"] = config_template["model"]["filter_channels"] = config_template["model"]["gin_channels"] = 768 + d_config_template["data"]["encoder_out_channels"] = 768 + elif args.speech_encoder == "vec256l9" or args.speech_encoder == 'hubertsoft': + config_template["model"]["ssl_dim"] = config_template["model"]["gin_channels"] = 256 + d_config_template["data"]["encoder_out_channels"] = 256 + elif args.speech_encoder == "whisper-ppg" or args.speech_encoder == 'cnhubertlarge': + config_template["model"]["ssl_dim"] = config_template["model"]["filter_channels"] = 
config_template["model"]["gin_channels"] = 1024 + d_config_template["data"]["encoder_out_channels"] = 1024 + elif args.speech_encoder == "whisper-ppg-large": + config_template["model"]["ssl_dim"] = config_template["model"]["filter_channels"] = config_template["model"]["gin_channels"] = 1280 + d_config_template["data"]["encoder_out_channels"] = 1280 - if args.vol_aug: - config_template["train"]["vol_aug"] = config_template["model"]["vol_embedding"] = True + if args.vol_aug: + config_template["train"]["vol_aug"] = config_template["model"]["vol_embedding"] = True - if args.tiny: - config_template["model"]["filter_channels"] = 512 + if args.tiny: + config_template["model"]["filter_channels"] = 512 - logger.info("Writing to configs/config.json") - with open("configs/config.json", "w") as f: - json.dump(config_template, f, indent=2) - logger.info("Writing to configs/diffusion.yaml") - du.save_config("configs/diffusion.yaml",d_config_template) + logger.info("Writing to configs/config.json") + with open("configs/config.json", "w") as f: + json.dump(config_template, f, indent=2) + logger.info("Writing to configs/diffusion.yaml") + du.save_config("configs/diffusion.yaml",d_config_template) diff --git a/test.py b/test.py index a4888390..4a46abab 100644 --- a/test.py +++ b/test.py @@ -3,14 +3,19 @@ def work(i): # logger.info("[red]Working[/red]: {}", i) - time.sleep(0.05) + # time.sleep(0.) + pass def main(): logger.info("Before") with logger.Progress() as progress: - for i in progress.track(range(100)): + for i in progress.track(range(100), description="Workkkkkking"): work(i) logger.info("After") + logger.warning("Warning") + logger.error("Error") + logger.debug("Debug") + logger.info("Info") if __name__ == "__main__": main() \ No newline at end of file diff --git a/train.py b/train.py index d523700b..e2f4b868 100644 --- a/train.py +++ b/train.py @@ -147,128 +147,130 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade net_g.train() net_d.train() - for batch_idx, items in enumerate(train_loader): - c, f0, spec, y, spk, lengths, uv,volume = items - g = spk.cuda(rank, non_blocking=True) - spec, y = spec.cuda(rank, non_blocking=True), y.cuda(rank, non_blocking=True) - c = c.cuda(rank, non_blocking=True) - f0 = f0.cuda(rank, non_blocking=True) - uv = uv.cuda(rank, non_blocking=True) - lengths = lengths.cuda(rank, non_blocking=True) - mel = spec_to_mel_torch( - spec, - hps.data.filter_length, - hps.data.n_mel_channels, - hps.data.sampling_rate, - hps.data.mel_fmin, - hps.data.mel_fmax) - - with autocast(enabled=hps.train.fp16_run, dtype=half_type): - y_hat, ids_slice, z_mask, \ - (z, z_p, m_p, logs_p, m_q, logs_q), pred_lf0, norm_lf0, lf0 = net_g(c, f0, uv, spec, g=g, c_lengths=lengths, - spec_lengths=lengths,vol = volume) - - y_mel = commons.slice_segments(mel, ids_slice, hps.train.segment_size // hps.data.hop_length) - y_hat_mel = mel_spectrogram_torch( - y_hat.squeeze(1), + enumerated_train_loader = enumerate(train_loader) + with logger.Progress() as progress: + for batch_idx, items in progress.track(enumerated_train_loader, description="Epoch {}".format(epoch)): + c, f0, spec, y, spk, lengths, uv,volume = items + g = spk.cuda(rank, non_blocking=True) + spec, y = spec.cuda(rank, non_blocking=True), y.cuda(rank, non_blocking=True) + c = c.cuda(rank, non_blocking=True) + f0 = f0.cuda(rank, non_blocking=True) + uv = uv.cuda(rank, non_blocking=True) + lengths = lengths.cuda(rank, non_blocking=True) + mel = spec_to_mel_torch( + spec, hps.data.filter_length, 
hps.data.n_mel_channels, hps.data.sampling_rate, - hps.data.hop_length, - hps.data.win_length, hps.data.mel_fmin, - hps.data.mel_fmax - ) - y = commons.slice_segments(y, ids_slice * hps.data.hop_length, hps.train.segment_size) # slice - - # Discriminator - y_d_hat_r, y_d_hat_g, _, _ = net_d(y, y_hat.detach()) - - with autocast(enabled=False, dtype=half_type): - loss_disc, losses_disc_r, losses_disc_g = discriminator_loss(y_d_hat_r, y_d_hat_g) - loss_disc_all = loss_disc - - optim_d.zero_grad() - scaler.scale(loss_disc_all).backward() - scaler.unscale_(optim_d) - grad_norm_d = commons.clip_grad_value_(net_d.parameters(), None) - scaler.step(optim_d) - - - with autocast(enabled=hps.train.fp16_run, dtype=half_type): - # Generator - y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = net_d(y, y_hat) - with autocast(enabled=False, dtype=half_type): - loss_mel = F.l1_loss(y_mel, y_hat_mel) * hps.train.c_mel - loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * hps.train.c_kl - loss_fm = feature_loss(fmap_r, fmap_g) - loss_gen, losses_gen = generator_loss(y_d_hat_g) - loss_lf0 = F.mse_loss(pred_lf0, lf0) if net_g.module.use_automatic_f0_prediction else 0 - loss_gen_all = loss_gen + loss_fm + loss_mel + loss_kl + loss_lf0 - optim_g.zero_grad() - scaler.scale(loss_gen_all).backward() - scaler.unscale_(optim_g) - grad_norm_g = commons.clip_grad_value_(net_g.parameters(), None) - scaler.step(optim_g) - scaler.update() - - if rank == 0: - if global_step % hps.train.log_interval == 0: - lr = optim_g.param_groups[0]['lr'] - losses = [loss_disc, loss_gen, loss_fm, loss_mel, loss_kl] - reference_loss=0 - for i in losses: - reference_loss += i - logger.info('Train Epoch: {} [{:.0f}%]'.format( - epoch, - 100. * batch_idx / len(train_loader))) - logger.info(f"Losses: {[x.item() for x in losses]}, step: {global_step}, lr: {lr}, reference_loss: {reference_loss}") - - scalar_dict = {"loss/g/total": loss_gen_all, "loss/d/total": loss_disc_all, "learning_rate": lr, - "grad_norm_d": grad_norm_d, "grad_norm_g": grad_norm_g} - scalar_dict.update({"loss/g/fm": loss_fm, "loss/g/mel": loss_mel, "loss/g/kl": loss_kl, - "loss/g/lf0": loss_lf0}) - - # scalar_dict.update({"loss/g/{}".format(i): v for i, v in enumerate(losses_gen)}) - # scalar_dict.update({"loss/d_r/{}".format(i): v for i, v in enumerate(losses_disc_r)}) - # scalar_dict.update({"loss/d_g/{}".format(i): v for i, v in enumerate(losses_disc_g)}) - image_dict = { - "slice/mel_org": utils.plot_spectrogram_to_numpy(y_mel[0].data.cpu().numpy()), - "slice/mel_gen": utils.plot_spectrogram_to_numpy(y_hat_mel[0].data.cpu().numpy()), - "all/mel": utils.plot_spectrogram_to_numpy(mel[0].data.cpu().numpy()) - } - - if net_g.module.use_automatic_f0_prediction: - image_dict.update({ - "all/lf0": utils.plot_data_to_numpy(lf0[0, 0, :].cpu().numpy(), - pred_lf0[0, 0, :].detach().cpu().numpy()), - "all/norm_lf0": utils.plot_data_to_numpy(lf0[0, 0, :].cpu().numpy(), - norm_lf0[0, 0, :].detach().cpu().numpy()) - }) - - utils.summarize( - writer=writer, - global_step=global_step, - images=image_dict, - scalars=scalar_dict + hps.data.mel_fmax) + + with autocast(enabled=hps.train.fp16_run, dtype=half_type): + y_hat, ids_slice, z_mask, \ + (z, z_p, m_p, logs_p, m_q, logs_q), pred_lf0, norm_lf0, lf0 = net_g(c, f0, uv, spec, g=g, c_lengths=lengths, + spec_lengths=lengths,vol = volume) + + y_mel = commons.slice_segments(mel, ids_slice, hps.train.segment_size // hps.data.hop_length) + y_hat_mel = mel_spectrogram_torch( + y_hat.squeeze(1), + hps.data.filter_length, + hps.data.n_mel_channels, + 
hps.data.sampling_rate, + hps.data.hop_length, + hps.data.win_length, + hps.data.mel_fmin, + hps.data.mel_fmax ) - # 达到保存步数或者 stop 文件存在 - if global_step % hps.train.eval_interval == 0 or os.path.exists("stop.txt"): - if os.path.exists("stop.txt"): - logger.info("stop.txt found, stop training") - evaluate(hps, net_g, eval_loader, writer_eval) - utils.save_checkpoint(net_g, optim_g, hps.train.learning_rate, epoch, - os.path.join(hps.model_dir, "G_{}.pth".format(global_step))) - utils.save_checkpoint(net_d, optim_d, hps.train.learning_rate, epoch, - os.path.join(hps.model_dir, "D_{}.pth".format(global_step))) - keep_ckpts = getattr(hps.train, 'keep_ckpts', 0) - if keep_ckpts > 0: - utils.clean_checkpoints(path_to_models=hps.model_dir, n_ckpts_to_keep=keep_ckpts, sort_by_time=True) - if os.path.exists("stop.txt"): - logger.info("good bye!") - os.remove("stop.txt") - os._exit(0) - global_step += 1 + y = commons.slice_segments(y, ids_slice * hps.data.hop_length, hps.train.segment_size) # slice + + # Discriminator + y_d_hat_r, y_d_hat_g, _, _ = net_d(y, y_hat.detach()) + + with autocast(enabled=False, dtype=half_type): + loss_disc, losses_disc_r, losses_disc_g = discriminator_loss(y_d_hat_r, y_d_hat_g) + loss_disc_all = loss_disc + + optim_d.zero_grad() + scaler.scale(loss_disc_all).backward() + scaler.unscale_(optim_d) + grad_norm_d = commons.clip_grad_value_(net_d.parameters(), None) + scaler.step(optim_d) + + + with autocast(enabled=hps.train.fp16_run, dtype=half_type): + # Generator + y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = net_d(y, y_hat) + with autocast(enabled=False, dtype=half_type): + loss_mel = F.l1_loss(y_mel, y_hat_mel) * hps.train.c_mel + loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * hps.train.c_kl + loss_fm = feature_loss(fmap_r, fmap_g) + loss_gen, losses_gen = generator_loss(y_d_hat_g) + loss_lf0 = F.mse_loss(pred_lf0, lf0) if net_g.module.use_automatic_f0_prediction else 0 + loss_gen_all = loss_gen + loss_fm + loss_mel + loss_kl + loss_lf0 + optim_g.zero_grad() + scaler.scale(loss_gen_all).backward() + scaler.unscale_(optim_g) + grad_norm_g = commons.clip_grad_value_(net_g.parameters(), None) + scaler.step(optim_g) + scaler.update() + + if rank == 0: + if global_step % hps.train.log_interval == 0: + lr = optim_g.param_groups[0]['lr'] + losses = [loss_disc, loss_gen, loss_fm, loss_mel, loss_kl] + reference_loss=0 + for i in losses: + reference_loss += i + logger.info('Train Epoch: {} [{:.0f}%]'.format( + epoch, + 100. 
* batch_idx / len(train_loader))) + logger.info(f"Losses: {[x.item() for x in losses]}, step: {global_step}, lr: {lr}, reference_loss: {reference_loss}") + + scalar_dict = {"loss/g/total": loss_gen_all, "loss/d/total": loss_disc_all, "learning_rate": lr, + "grad_norm_d": grad_norm_d, "grad_norm_g": grad_norm_g} + scalar_dict.update({"loss/g/fm": loss_fm, "loss/g/mel": loss_mel, "loss/g/kl": loss_kl, + "loss/g/lf0": loss_lf0}) + + # scalar_dict.update({"loss/g/{}".format(i): v for i, v in enumerate(losses_gen)}) + # scalar_dict.update({"loss/d_r/{}".format(i): v for i, v in enumerate(losses_disc_r)}) + # scalar_dict.update({"loss/d_g/{}".format(i): v for i, v in enumerate(losses_disc_g)}) + image_dict = { + "slice/mel_org": utils.plot_spectrogram_to_numpy(y_mel[0].data.cpu().numpy()), + "slice/mel_gen": utils.plot_spectrogram_to_numpy(y_hat_mel[0].data.cpu().numpy()), + "all/mel": utils.plot_spectrogram_to_numpy(mel[0].data.cpu().numpy()) + } + + if net_g.module.use_automatic_f0_prediction: + image_dict.update({ + "all/lf0": utils.plot_data_to_numpy(lf0[0, 0, :].cpu().numpy(), + pred_lf0[0, 0, :].detach().cpu().numpy()), + "all/norm_lf0": utils.plot_data_to_numpy(lf0[0, 0, :].cpu().numpy(), + norm_lf0[0, 0, :].detach().cpu().numpy()) + }) + + utils.summarize( + writer=writer, + global_step=global_step, + images=image_dict, + scalars=scalar_dict + ) + # 达到保存步数或者 stop 文件存在 + if global_step % hps.train.eval_interval == 0 or os.path.exists("stop.txt"): + if os.path.exists("stop.txt"): + logger.info("stop.txt found, stop training") + evaluate(hps, net_g, eval_loader, writer_eval) + utils.save_checkpoint(net_g, optim_g, hps.train.learning_rate, epoch, + os.path.join(hps.model_dir, "G_{}.pth".format(global_step))) + utils.save_checkpoint(net_d, optim_d, hps.train.learning_rate, epoch, + os.path.join(hps.model_dir, "D_{}.pth".format(global_step))) + keep_ckpts = getattr(hps.train, 'keep_ckpts', 0) + if keep_ckpts > 0: + utils.clean_checkpoints(path_to_models=hps.model_dir, n_ckpts_to_keep=keep_ckpts, sort_by_time=True) + if os.path.exists("stop.txt"): + logger.info("good bye!") + os.remove("stop.txt") + os._exit(0) + global_step += 1 if rank == 0: global start_time From 7080166af880b0f5cec49a0feac6f661a5cf4904 Mon Sep 17 00:00:00 2001 From: HuanLinOTO Date: Sat, 11 Nov 2023 13:55:13 +0800 Subject: [PATCH 6/8] chore: more logger --- .gitignore | 2 +- logger/__init__.py | 3 ++- preprocess_hubert_f0.py | 2 +- resample.py | 17 +++++++++-------- 4 files changed, 13 insertions(+), 11 deletions(-) diff --git a/.gitignore b/.gitignore index 07f692e4..415600fe 100644 --- a/.gitignore +++ b/.gitignore @@ -167,5 +167,5 @@ pretrain/ trained/**/ configs/ -filelists/ +filelists/* filelists/val.txt diff --git a/logger/__init__.py b/logger/__init__.py index fed6afa9..d512cf1d 100644 --- a/logger/__init__.py +++ b/logger/__init__.py @@ -39,7 +39,7 @@ def format_level(str, length): def default_format(record): # print(record) - return f"[green]{record['time'].strftime('%Y-%m-%d %H:%M:%S')}[/green] | [level]{format_level(record['level'].name,7)}[/level] | [cyan]{record['file'].path.replace(os.getcwd()+os.sep,'')}[/cyan] - [level]{record['message']}[/level]\n" + return f"[green]{record['time'].strftime('%Y-%m-%d %H:%M:%S')}[/green] | [level]{format_level(record['level'].name,7)}[/level] | [cyan]{record['file'].path.replace(os.getcwd()+os.sep,'')}:{record['line']}[/cyan] - [level]{record['message']}[/level]\n" logger.add(lambda m: console.print(m, end=""), format=default_format, colorize=True) @@ -58,6 +58,7 @@ def 
addLogger(path): info = logger.info error = logger.error warning = logger.warning +warn = logger.warning debug = logger.debug diff --git a/preprocess_hubert_f0.py b/preprocess_hubert_f0.py index 0c482104..edc75e71 100644 --- a/preprocess_hubert_f0.py +++ b/preprocess_hubert_f0.py @@ -10,7 +10,7 @@ import numpy as np import torch import torch.multiprocessing as mp -from loguru import logger +import logger from tqdm import tqdm import diffusion.logger.utils as du diff --git a/resample.py b/resample.py index af421fde..6c32259e 100644 --- a/resample.py +++ b/resample.py @@ -6,9 +6,9 @@ import librosa import numpy as np -from rich.progress import track from scipy.io import wavfile +import logger def load_wav(wav_path): return librosa.load(wav_path, sr=None) @@ -76,13 +76,14 @@ def process_all_speakers(): def process_all_speakers(): process_count = 30 if os.cpu_count() > 60 else (os.cpu_count() - 2 if os.cpu_count() > 4 else 1) with ProcessPoolExecutor(max_workers=process_count) as executor: - for speaker in speakers: - spk_dir = os.path.join(args.in_dir, speaker) - if os.path.isdir(spk_dir): - print(spk_dir) - futures = [executor.submit(process, (spk_dir, i, args)) for i in os.listdir(spk_dir) if i.endswith("wav")] - for _ in track(concurrent.futures.as_completed(futures), total=len(futures), description="resampling:"): - pass + with logger.Progress() as progress: + for speaker in speakers: + spk_dir = os.path.join(args.in_dir, speaker) + if os.path.isdir(spk_dir): + print(spk_dir) + futures = [executor.submit(process, (spk_dir, i, args)) for i in os.listdir(spk_dir) if i.endswith("wav")] + for _ in progress.track(concurrent.futures.as_completed(futures), total=len(futures), description="resampling:"): + pass if __name__ == "__main__": From 1c234be8c5e646aa3547998607c94cc0952fe22f Mon Sep 17 00:00:00 2001 From: HuanLinOTO Date: Sat, 11 Nov 2023 15:04:59 +0800 Subject: [PATCH 7/8] chore(train): logger --- logger/__init__.py | 4 +- train.py | 287 +++++++++++++++++++++++---------------------- 2 files changed, 149 insertions(+), 142 deletions(-) diff --git a/logger/__init__.py b/logger/__init__.py index d512cf1d..b6a59117 100644 --- a/logger/__init__.py +++ b/logger/__init__.py @@ -38,7 +38,7 @@ def format_level(str, length): return str def default_format(record): - # print(record) + print(record) return f"[green]{record['time'].strftime('%Y-%m-%d %H:%M:%S')}[/green] | [level]{format_level(record['level'].name,7)}[/level] | [cyan]{record['file'].path.replace(os.getcwd()+os.sep,'')}:{record['line']}[/cyan] - [level]{record['message']}[/level]\n" @@ -61,6 +61,8 @@ def addLogger(path): warn = logger.warning debug = logger.debug +def hps(hps): + console.print(hps) def Progress(): return _Progress( diff --git a/train.py b/train.py index e2f4b868..5393e150 100644 --- a/train.py +++ b/train.py @@ -2,6 +2,7 @@ import multiprocessing import os import time +from rich import progress import torch import torch.distributed as dist @@ -48,7 +49,7 @@ def run(rank, n_gpus, hps): global global_step if rank == 0: logger = utils.get_logger(hps.model_dir) - logger.info(hps) + logger.hps(hps) utils.check_git_hash(hps.model_dir) writer = SummaryWriter(log_dir=hps.model_dir) writer_eval = SummaryWriter(log_dir=os.path.join(hps.model_dir, "eval")) @@ -112,27 +113,27 @@ def run(rank, n_gpus, hps): scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2) scaler = GradScaler(enabled=hps.train.fp16_run) - - for epoch in range(epoch_str, hps.train.epochs + 1): - # set up 
From 1c234be8c5e646aa3547998607c94cc0952fe22f Mon Sep 17 00:00:00 2001
From: HuanLinOTO
Date: Sat, 11 Nov 2023 15:04:59 +0800
Subject: [PATCH 7/8] chore(train): logger

---
 logger/__init__.py |   4 +-
 train.py           | 287 +++++++++++++++++++++++----------------------
 2 files changed, 149 insertions(+), 142 deletions(-)

diff --git a/logger/__init__.py b/logger/__init__.py
index d512cf1d..b6a59117 100644
--- a/logger/__init__.py
+++ b/logger/__init__.py
@@ -38,7 +38,7 @@ def format_level(str, length):
     return str
 
 def default_format(record):
-    # print(record)
+    print(record)
     return f"[green]{record['time'].strftime('%Y-%m-%d %H:%M:%S')}[/green] | [level]{format_level(record['level'].name,7)}[/level] | [cyan]{record['file'].path.replace(os.getcwd()+os.sep,'')}:{record['line']}[/cyan] - [level]{record['message']}[/level]\n"
 
@@ -61,6 +61,8 @@ def addLogger(path):
 warn = logger.warning
 debug = logger.debug
 
+def hps(hps):
+    console.print(hps)
 
 def Progress():
     return _Progress(
diff --git a/train.py b/train.py
index e2f4b868..5393e150 100644
--- a/train.py
+++ b/train.py
@@ -2,6 +2,7 @@
 import multiprocessing
 import os
 import time
+from rich import progress
 
 import torch
 import torch.distributed as dist
@@ -48,7 +49,7 @@ def run(rank, n_gpus, hps):
     global global_step
     if rank == 0:
         logger = utils.get_logger(hps.model_dir)
-        logger.info(hps)
+        logger.hps(hps)
         utils.check_git_hash(hps.model_dir)
         writer = SummaryWriter(log_dir=hps.model_dir)
         writer_eval = SummaryWriter(log_dir=os.path.join(hps.model_dir, "eval"))
@@ -112,27 +113,27 @@ def run(rank, n_gpus, hps):
     scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2)
     scaler = GradScaler(enabled=hps.train.fp16_run)
 
-    for epoch in range(epoch_str, hps.train.epochs + 1):
-        # set up warm-up learning rate
-        if epoch <= warmup_epoch:
-            for param_group in optim_g.param_groups:
-                param_group['lr'] = hps.train.learning_rate / warmup_epoch * epoch
-            for param_group in optim_d.param_groups:
-                param_group['lr'] = hps.train.learning_rate / warmup_epoch * epoch
-        # training
-        if rank == 0:
-            train_and_evaluate(rank, epoch, hps, [net_g, net_d], [optim_g, optim_d], [scheduler_g, scheduler_d], scaler,
-                               [train_loader, eval_loader], logger, [writer, writer_eval])
-        else:
-            train_and_evaluate(rank, epoch, hps, [net_g, net_d], [optim_g, optim_d], [scheduler_g, scheduler_d], scaler,
-                               [train_loader, None], None, None)
-        # update learning rate
-        scheduler_g.step()
-        scheduler_d.step()
+    with logger.Progress() as progress:
+        for epoch in range(epoch_str, hps.train.epochs + 1):
+            # set up warm-up learning rate
+            if epoch <= warmup_epoch:
+                for param_group in optim_g.param_groups:
+                    param_group['lr'] = hps.train.learning_rate / warmup_epoch * epoch
+                for param_group in optim_d.param_groups:
+                    param_group['lr'] = hps.train.learning_rate / warmup_epoch * epoch
+            # training
+            if rank == 0:
+                train_and_evaluate(rank, epoch, hps, [net_g, net_d], [optim_g, optim_d], [scheduler_g, scheduler_d], scaler,
+                                   [train_loader, eval_loader], logger, [writer, writer_eval], progress)
+            else:
+                train_and_evaluate(rank, epoch, hps, [net_g, net_d], [optim_g, optim_d], [scheduler_g, scheduler_d], scaler,
+                                   [train_loader, None], None, None, progress)
+            # update learning rate
+            scheduler_g.step()
+            scheduler_d.step()
 
 
-def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loaders, logger, writers):
+def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loaders, logger, writers, progress: progress.Progress):
     net_g, net_d = nets
     optim_g, optim_d = optims
     scheduler_g, scheduler_d = schedulers
@@ -148,135 +149,139 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
 
     net_g.train()
     net_d.train()
     enumerated_train_loader = enumerate(train_loader)
-    with logger.Progress() as progress:
-        for batch_idx, items in progress.track(enumerated_train_loader, description="Epoch {}".format(epoch)):
-            c, f0, spec, y, spk, lengths, uv,volume = items
-            g = spk.cuda(rank, non_blocking=True)
-            spec, y = spec.cuda(rank, non_blocking=True), y.cuda(rank, non_blocking=True)
-            c = c.cuda(rank, non_blocking=True)
-            f0 = f0.cuda(rank, non_blocking=True)
-            uv = uv.cuda(rank, non_blocking=True)
-            lengths = lengths.cuda(rank, non_blocking=True)
-            mel = spec_to_mel_torch(
-                spec,
-                hps.data.filter_length,
-                hps.data.n_mel_channels,
-                hps.data.sampling_rate,
-                hps.data.mel_fmin,
-                hps.data.mel_fmax)
-
-            with autocast(enabled=hps.train.fp16_run, dtype=half_type):
-                y_hat, ids_slice, z_mask, \
-                (z, z_p, m_p, logs_p, m_q, logs_q), pred_lf0, norm_lf0, lf0 = net_g(c, f0, uv, spec, g=g, c_lengths=lengths,
-                                                                                    spec_lengths=lengths,vol = volume)
-
-                y_mel = commons.slice_segments(mel, ids_slice, hps.train.segment_size // hps.data.hop_length)
-                y_hat_mel = mel_spectrogram_torch(
-                    y_hat.squeeze(1),
-                    hps.data.filter_length,
-                    hps.data.n_mel_channels,
-                    hps.data.sampling_rate,
-                    hps.data.hop_length,
-                    hps.data.win_length,
-                    hps.data.mel_fmin,
-                    hps.data.mel_fmax
-                )
-                y = commons.slice_segments(y, ids_slice * hps.data.hop_length, hps.train.segment_size)  # slice
-
-                # Discriminator
-                y_d_hat_r, y_d_hat_g, _, _ = net_d(y, y_hat.detach())
-
-                with autocast(enabled=False, dtype=half_type):
-                    loss_disc, losses_disc_r, losses_disc_g = discriminator_loss(y_d_hat_r, y_d_hat_g)
-                    loss_disc_all = loss_disc
-
-            optim_d.zero_grad()
-            scaler.scale(loss_disc_all).backward()
-            scaler.unscale_(optim_d)
-            grad_norm_d = commons.clip_grad_value_(net_d.parameters(), None)
-            scaler.step(optim_d)
-
-
-            with autocast(enabled=hps.train.fp16_run, dtype=half_type):
-                # Generator
-                y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = net_d(y, y_hat)
-                with autocast(enabled=False, dtype=half_type):
-                    loss_mel = F.l1_loss(y_mel, y_hat_mel) * hps.train.c_mel
-                    loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * hps.train.c_kl
-                    loss_fm = feature_loss(fmap_r, fmap_g)
-                    loss_gen, losses_gen = generator_loss(y_d_hat_g)
-                    loss_lf0 = F.mse_loss(pred_lf0, lf0) if net_g.module.use_automatic_f0_prediction else 0
-                    loss_gen_all = loss_gen + loss_fm + loss_mel + loss_kl + loss_lf0
-            optim_g.zero_grad()
-            scaler.scale(loss_gen_all).backward()
-            scaler.unscale_(optim_g)
-            grad_norm_g = commons.clip_grad_value_(net_g.parameters(), None)
-            scaler.step(optim_g)
-            scaler.update()
-
-            if rank == 0:
-                if global_step % hps.train.log_interval == 0:
-                    lr = optim_g.param_groups[0]['lr']
-                    losses = [loss_disc, loss_gen, loss_fm, loss_mel, loss_kl]
-                    reference_loss=0
-                    for i in losses:
-                        reference_loss += i
-                    logger.info('Train Epoch: {} [{:.0f}%]'.format(
-                        epoch,
-                        100. * batch_idx / len(train_loader)))
-                    logger.info(f"Losses: {[x.item() for x in losses]}, step: {global_step}, lr: {lr}, reference_loss: {reference_loss}")
-
-                    scalar_dict = {"loss/g/total": loss_gen_all, "loss/d/total": loss_disc_all, "learning_rate": lr,
-                                   "grad_norm_d": grad_norm_d, "grad_norm_g": grad_norm_g}
-                    scalar_dict.update({"loss/g/fm": loss_fm, "loss/g/mel": loss_mel, "loss/g/kl": loss_kl,
-                                        "loss/g/lf0": loss_lf0})
-
-                    # scalar_dict.update({"loss/g/{}".format(i): v for i, v in enumerate(losses_gen)})
-                    # scalar_dict.update({"loss/d_r/{}".format(i): v for i, v in enumerate(losses_disc_r)})
-                    # scalar_dict.update({"loss/d_g/{}".format(i): v for i, v in enumerate(losses_disc_g)})
-                    image_dict = {
-                        "slice/mel_org": utils.plot_spectrogram_to_numpy(y_mel[0].data.cpu().numpy()),
-                        "slice/mel_gen": utils.plot_spectrogram_to_numpy(y_hat_mel[0].data.cpu().numpy()),
-                        "all/mel": utils.plot_spectrogram_to_numpy(mel[0].data.cpu().numpy())
-                    }
-
-                    if net_g.module.use_automatic_f0_prediction:
-                        image_dict.update({
-                            "all/lf0": utils.plot_data_to_numpy(lf0[0, 0, :].cpu().numpy(),
-                                                                pred_lf0[0, 0, :].detach().cpu().numpy()),
-                            "all/norm_lf0": utils.plot_data_to_numpy(lf0[0, 0, :].cpu().numpy(),
-                                                                     norm_lf0[0, 0, :].detach().cpu().numpy())
-                        })
-
-                    utils.summarize(
-                        writer=writer,
-                        global_step=global_step,
-                        images=image_dict,
-                        scalars=scalar_dict
-                    )
-                # 达到保存步数或者 stop 文件存在
-                if global_step % hps.train.eval_interval == 0 or os.path.exists("stop.txt"):
-                    if os.path.exists("stop.txt"):
-                        logger.info("stop.txt found, stop training")
-                    evaluate(hps, net_g, eval_loader, writer_eval)
-                    utils.save_checkpoint(net_g, optim_g, hps.train.learning_rate, epoch,
-                                          os.path.join(hps.model_dir, "G_{}.pth".format(global_step)))
-                    utils.save_checkpoint(net_d, optim_d, hps.train.learning_rate, epoch,
-                                          os.path.join(hps.model_dir, "D_{}.pth".format(global_step)))
-                    keep_ckpts = getattr(hps.train, 'keep_ckpts', 0)
-                    if keep_ckpts > 0:
-                        utils.clean_checkpoints(path_to_models=hps.model_dir, n_ckpts_to_keep=keep_ckpts, sort_by_time=True)
-                    if os.path.exists("stop.txt"):
-                        logger.info("good bye!")
-                        os.remove("stop.txt")
-                        os._exit(0)
-            global_step += 1
+    # logger.info(f"enumerated_train_loader len: {len(enumerated_train_loader)}")
+    # "Epoch {}".format(epoch)
+    task = progress.add_task(f"Epoch {epoch}", total=len(train_loader))
+    for batch_idx, items in enumerated_train_loader:
+        # logger.info(f"finish {progress.} ")
+        c, f0, spec, y, spk, lengths, uv,volume = items
+        g = spk.cuda(rank, non_blocking=True)
+        spec, y = spec.cuda(rank, non_blocking=True), y.cuda(rank, non_blocking=True)
+        c = c.cuda(rank, non_blocking=True)
+        f0 = f0.cuda(rank, non_blocking=True)
+        uv = uv.cuda(rank, non_blocking=True)
+        lengths = lengths.cuda(rank, non_blocking=True)
+        mel = spec_to_mel_torch(
+            spec,
+            hps.data.filter_length,
+            hps.data.n_mel_channels,
+            hps.data.sampling_rate,
+            hps.data.mel_fmin,
+            hps.data.mel_fmax)
+
+        with autocast(enabled=hps.train.fp16_run, dtype=half_type):
+            y_hat, ids_slice, z_mask, \
+            (z, z_p, m_p, logs_p, m_q, logs_q), pred_lf0, norm_lf0, lf0 = net_g(c, f0, uv, spec, g=g, c_lengths=lengths,
+                                                                                spec_lengths=lengths,vol = volume)
+
+            y_mel = commons.slice_segments(mel, ids_slice, hps.train.segment_size // hps.data.hop_length)
+            y_hat_mel = mel_spectrogram_torch(
+                y_hat.squeeze(1),
+                hps.data.filter_length,
+                hps.data.n_mel_channels,
+                hps.data.sampling_rate,
+                hps.data.hop_length,
+                hps.data.win_length,
+                hps.data.mel_fmin,
+                hps.data.mel_fmax
+            )
+            y = commons.slice_segments(y, ids_slice * hps.data.hop_length, hps.train.segment_size)  # slice
+
+            # Discriminator
+            y_d_hat_r, y_d_hat_g, _, _ = net_d(y, y_hat.detach())
+
+            with autocast(enabled=False, dtype=half_type):
+                loss_disc, losses_disc_r, losses_disc_g = discriminator_loss(y_d_hat_r, y_d_hat_g)
+                loss_disc_all = loss_disc
+
+        optim_d.zero_grad()
+        scaler.scale(loss_disc_all).backward()
+        scaler.unscale_(optim_d)
+        grad_norm_d = commons.clip_grad_value_(net_d.parameters(), None)
+        scaler.step(optim_d)
+
+
+        with autocast(enabled=hps.train.fp16_run, dtype=half_type):
+            # Generator
+            y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = net_d(y, y_hat)
+            with autocast(enabled=False, dtype=half_type):
+                loss_mel = F.l1_loss(y_mel, y_hat_mel) * hps.train.c_mel
+                loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * hps.train.c_kl
+                loss_fm = feature_loss(fmap_r, fmap_g)
+                loss_gen, losses_gen = generator_loss(y_d_hat_g)
+                loss_lf0 = F.mse_loss(pred_lf0, lf0) if net_g.module.use_automatic_f0_prediction else 0
+                loss_gen_all = loss_gen + loss_fm + loss_mel + loss_kl + loss_lf0
+        optim_g.zero_grad()
+        scaler.scale(loss_gen_all).backward()
+        scaler.unscale_(optim_g)
+        grad_norm_g = commons.clip_grad_value_(net_g.parameters(), None)
+        scaler.step(optim_g)
+        scaler.update()
+        if rank == 0:
+            if global_step % hps.train.log_interval == 0:
+                lr = optim_g.param_groups[0]['lr']
+                losses = [loss_disc, loss_gen, loss_fm, loss_mel, loss_kl]
+                reference_loss=0
+                for i in losses:
+                    reference_loss += i
+                logger.info('Train Epoch: {} [{:.0f}%]'.format(
+                    epoch,
+                    100. * batch_idx / len(train_loader)))
+                logger.info(f"Losses: {[x.item() for x in losses]}, step: {global_step}, lr: {lr}, reference_loss: {reference_loss}")
+
+                scalar_dict = {"loss/g/total": loss_gen_all, "loss/d/total": loss_disc_all, "learning_rate": lr,
+                               "grad_norm_d": grad_norm_d, "grad_norm_g": grad_norm_g}
+                scalar_dict.update({"loss/g/fm": loss_fm, "loss/g/mel": loss_mel, "loss/g/kl": loss_kl,
+                                    "loss/g/lf0": loss_lf0})
+
+                # scalar_dict.update({"loss/g/{}".format(i): v for i, v in enumerate(losses_gen)})
+                # scalar_dict.update({"loss/d_r/{}".format(i): v for i, v in enumerate(losses_disc_r)})
+                # scalar_dict.update({"loss/d_g/{}".format(i): v for i, v in enumerate(losses_disc_g)})
+                image_dict = {
+                    "slice/mel_org": utils.plot_spectrogram_to_numpy(y_mel[0].data.cpu().numpy()),
+                    "slice/mel_gen": utils.plot_spectrogram_to_numpy(y_hat_mel[0].data.cpu().numpy()),
+                    "all/mel": utils.plot_spectrogram_to_numpy(mel[0].data.cpu().numpy())
+                }
+
+                if net_g.module.use_automatic_f0_prediction:
+                    image_dict.update({
+                        "all/lf0": utils.plot_data_to_numpy(lf0[0, 0, :].cpu().numpy(),
+                                                            pred_lf0[0, 0, :].detach().cpu().numpy()),
+                        "all/norm_lf0": utils.plot_data_to_numpy(lf0[0, 0, :].cpu().numpy(),
+                                                                 norm_lf0[0, 0, :].detach().cpu().numpy())
+                    })
+
+                utils.summarize(
+                    writer=writer,
+                    global_step=global_step,
+                    images=image_dict,
+                    scalars=scalar_dict
+                )
+            # save-interval step reached, or a stop.txt file exists
+            if global_step % hps.train.eval_interval == 0 or os.path.exists("stop.txt"):
+                if os.path.exists("stop.txt"):
+                    logger.info("stop.txt found, stop training")
+                evaluate(hps, net_g, eval_loader, writer_eval)
+                utils.save_checkpoint(net_g, optim_g, hps.train.learning_rate, epoch,
+                                      os.path.join(hps.model_dir, "G_{}.pth".format(global_step)))
+                utils.save_checkpoint(net_d, optim_d, hps.train.learning_rate, epoch,
+                                      os.path.join(hps.model_dir, "D_{}.pth".format(global_step)))
+                keep_ckpts = getattr(hps.train, 'keep_ckpts', 0)
+                if keep_ckpts > 0:
+                    utils.clean_checkpoints(path_to_models=hps.model_dir, n_ckpts_to_keep=keep_ckpts, sort_by_time=True)
+                if os.path.exists("stop.txt"):
+                    logger.info("good bye!")
+                    os.remove("stop.txt")
+                    os._exit(0)
+        global_step += 1
+        progress.advance(task)
+    progress.remove_task(task)
 
     if rank == 0:
         global start_time
         now = time.time()
-        durtaion = format(now - start_time, '.2f')
-        logger.info(f'====> Epoch: {epoch}, cost {durtaion} s')
+        duration = format(now - start_time, '.2f')  # this used to be 'durtaion', let me see who misspelled it
+        logger.info(f'Epoch: {epoch} finished, cost {duration} s, {round(float(duration)/len(train_loader),3)}s per batch')
        start_time = now
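The train.py change above swaps `progress.track(...)` for explicit task management, so that a single `Progress` instance, opened once in `run()`, survives across every epoch instead of being created and torn down per epoch. Reduced to its core, the pattern is the sketch below (plain `rich.progress`, with a sleep standing in for the real training step):

```python
# Minimal sketch of the long-lived-Progress pattern adopted above:
# one Progress for the whole run, one task per epoch, advanced per batch.
import time

from rich.progress import Progress

train_loader = range(100)  # stand-in for the real DataLoader

with Progress() as progress:
    for epoch in range(1, 4):
        task = progress.add_task(f"Epoch {epoch}", total=len(train_loader))
        for batch_idx in train_loader:
            time.sleep(0.01)          # stand-in for one training step
            progress.advance(task)    # tick the bar once per batch
        progress.remove_task(task)    # drop the finished bar before the next epoch
```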
From c797641be8e6dc90b912862f0c70002cd105652d Mon Sep 17 00:00:00 2001
From: HuanLinOTO
Date: Sat, 11 Nov 2023 20:46:51 +0800
Subject: [PATCH 8/8] chore(preprocess): logger

---
 .gitignore                        |  2 ++
 logger/__init__.py                |  3 +-
 preprocess_hubert_f0.py           | 50 +++++++++++++++++++++----------
 test.py                           | 21 -------------
 utils.py                          | 28 ++++++++---------
 vencoder/CNHubertLarge.py         |  5 ++--
 vencoder/ContentVec256L12_Onnx.py |  5 ++--
 vencoder/ContentVec256L9.py       |  5 ++--
 vencoder/ContentVec256L9_Onnx.py  |  5 ++--
 vencoder/ContentVec768L12.py      |  5 ++--
 vencoder/ContentVec768L12_Onnx.py |  5 ++--
 vencoder/DPHubert.py              |  5 ++--
 vencoder/HubertSoft.py            |  5 ++--
 vencoder/HubertSoft_Onnx.py       |  5 ++--
 vencoder/WavLMBasePlus.py         |  5 ++--
 15 files changed, 81 insertions(+), 73 deletions(-)
 delete mode 100644 test.py

diff --git a/.gitignore b/.gitignore
index 415600fe..f2f83601 100644
--- a/.gitignore
+++ b/.gitignore
@@ -169,3 +169,5 @@ trained/**/
 configs/
 filelists/*
 filelists/val.txt
+test_pipe.py
+test_queue.py
diff --git a/logger/__init__.py b/logger/__init__.py
index b6a59117..e6cf5e9b 100644
--- a/logger/__init__.py
+++ b/logger/__init__.py
@@ -38,7 +38,6 @@ def format_level(str, length):
     return str
 
 def default_format(record):
-    print(record)
     return f"[green]{record['time'].strftime('%Y-%m-%d %H:%M:%S')}[/green] | [level]{format_level(record['level'].name,7)}[/level] | [cyan]{record['file'].path.replace(os.getcwd()+os.sep,'')}:{record['line']}[/cyan] - [level]{record['message']}[/level]\n"
 
@@ -74,4 +73,4 @@ def Progress():
         TextColumn("[red]*Elapsed[/red]"),
         TimeElapsedColumn(),
         console=console,
-    )
+    )
\ No newline at end of file
diff --git a/preprocess_hubert_f0.py b/preprocess_hubert_f0.py
index edc75e71..f1a94b81 100644
--- a/preprocess_hubert_f0.py
+++ b/preprocess_hubert_f0.py
@@ -1,5 +1,6 @@
 import argparse
 import logging
+from multiprocessing import Pipe
 import os
 import random
 from concurrent.futures import ProcessPoolExecutor
@@ -103,30 +104,47 @@ def process_one(filename, hmodel, f0p, device, diff=False, mel_extractor=None):
             np.save(aug_vol_path,aug_vol.to('cpu').numpy())
 
 
-def process_batch(file_chunk, f0p, diff=False, mel_extractor=None, device="cpu"):
-    logger.info("Loading speech encoder for content...")
+def process_batch(pipe, file_chunk, f0p, diff=False, mel_extractor=None, device="cpu"):
+    # logger.info("Loading speech encoder for content...")
     rank = mp.current_process()._identity
     rank = rank[0] if len(rank) > 0 else 0
     if torch.cuda.is_available():
         gpu_id = rank % torch.cuda.device_count()
         device = torch.device(f"cuda:{gpu_id}")
-    logger.info(f"Rank {rank} uses device {device}")
-    hmodel = utils.get_speech_encoder(speech_encoder, device=device)
-    logger.info(f"Loaded speech encoder for rank {rank}")
-    for filename in tqdm(file_chunk, position = rank):
+    # logger.info(f"Rank {rank} uses device {device}")
+    hmodel = utils.get_speech_encoder(speech_encoder, device=device, log=False)
+    # logger.info(f"Loaded speech encoder for rank {rank}")
+    # for filename in tqdm(file_chunk, position = rank):
+    #     process_one(filename, hmodel, f0p, device, diff, mel_extractor)
+    for filename in file_chunk:
         process_one(filename, hmodel, f0p, device, diff, mel_extractor)
+        pipe.send(1)
+        # logger.info(f"Rank {rank} finished {filename}")
 
 
 def parallel_process(filenames, num_processes, f0p, diff, mel_extractor, device):
-    with ProcessPoolExecutor(max_workers=num_processes) as executor:
-        tasks = []
-        for i in range(num_processes):
-            start = int(i * len(filenames) / num_processes)
-            end = int((i + 1) * len(filenames) / num_processes)
-            file_chunk = filenames[start:end]
-            tasks.append(executor.submit(process_batch, file_chunk, f0p, diff, mel_extractor, device=device))
-        for task in tqdm(tasks, position = 0):
-            task.result()
-
+    with logger.Progress() as progress:
+        task = progress.add_task("Preprocessing", total=len(filenames))
+        with ProcessPoolExecutor(max_workers=num_processes) as executor:
+            tasks = []
+            pipes = []
+            for i in range(num_processes):
+                start = int(i * len(filenames) / num_processes)
+                end = int((i + 1) * len(filenames) / num_processes)
+                file_chunk = filenames[start:end]
+                parent_conn, child_conn = Pipe()
+                pipes.append((parent_conn, child_conn))
+                tasks.append(executor.submit(process_batch, child_conn, file_chunk, f0p, diff, mel_extractor, device=device))
+            while True:
+                # if all([task.done() for task in tasks]):
+                #     break
+                for parent_conn, child_conn in pipes:
+                    # if not parent_conn.empty():
+                    parent_conn.recv()
+                    # logger.debug("Received message")
+                    progress.advance(task)
+                    # progress.refresh()
+                if progress.finished:
+                    break
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument('-d', '--device', type=str, default=None)
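`parallel_process` above replaces the per-worker `tqdm` bars with a single bar driven over `multiprocessing.Pipe`: every worker sends one token per finished file, and the parent advances the bar on each `recv()`. The idea in isolation (with a hypothetical `do_work` standing in for `process_batch`/`process_one`; the blocking round-robin `recv()` assumes equally sized chunks, which is why the patch splits `filenames` evenly):

```python
# Self-contained sketch of Pipe-based progress aggregation across workers.
from concurrent.futures import ProcessPoolExecutor
from multiprocessing import Pipe

from rich.progress import Progress

def do_work(conn, chunk):
    for item in chunk:
        _ = item * item  # stand-in for the real per-file work
        conn.send(1)     # report one finished item to the parent

if __name__ == "__main__":
    items = list(range(40))
    num_workers = 4
    chunks = [items[i::num_workers] for i in range(num_workers)]  # equal sizes
    with Progress() as progress:
        task = progress.add_task("Preprocessing", total=len(items))
        with ProcessPoolExecutor(max_workers=num_workers) as executor:
            parent_conns = []
            for chunk in chunks:
                parent_conn, child_conn = Pipe()
                parent_conns.append(parent_conn)
                executor.submit(do_work, child_conn, chunk)
            done = 0
            while done < len(items):
                for parent_conn in parent_conns:
                    parent_conn.recv()    # blocks until that worker finishes an item
                    progress.advance(task)
                    done += 1
```

The blocking `recv()` doubles as the synchronization point, so no `task.done()` polling is needed; if the chunks could be uneven, the loop would have to select on the pipes instead of cycling them in a fixed order.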
diff --git a/test.py b/test.py
deleted file mode 100644
index 4a46abab..00000000
--- a/test.py
+++ /dev/null
@@ -1,21 +0,0 @@
-import time
-import logger
-
-def work(i):
-    # logger.info("[red]Working[/red]: {}", i)
-    # time.sleep(0.)
-    pass
-
-def main():
-    logger.info("Before")
-    with logger.Progress() as progress:
-        for i in progress.track(range(100), description="Workkkkkking"):
-            work(i)
-    logger.info("After")
-    logger.warning("Warning")
-    logger.error("Error")
-    logger.debug("Debug")
-    logger.info("Info")
-
-if __name__ == "__main__":
-    main()
\ No newline at end of file
diff --git a/utils.py b/utils.py
index ca1f349c..496a20f6 100644
--- a/utils.py
+++ b/utils.py
@@ -110,46 +110,46 @@ def get_f0_predictor(f0_predictor,hop_length,sampling_rate,**kargs):
         raise Exception("Unknown f0 predictor")
     return f0_predictor_object
 
-def get_speech_encoder(speech_encoder,device=None,**kargs):
+def get_speech_encoder(speech_encoder,device=None,log=True,**kargs):
     if speech_encoder == "vec768l12":
         from vencoder.ContentVec768L12 import ContentVec768L12
-        speech_encoder_object = ContentVec768L12(device = device)
+        speech_encoder_object = ContentVec768L12(device = device,log=log)
     elif speech_encoder == "vec256l9":
         from vencoder.ContentVec256L9 import ContentVec256L9
-        speech_encoder_object = ContentVec256L9(device = device)
+        speech_encoder_object = ContentVec256L9(device = device,log=log)
     elif speech_encoder == "vec256l9-onnx":
         from vencoder.ContentVec256L9_Onnx import ContentVec256L9_Onnx
-        speech_encoder_object = ContentVec256L9_Onnx(device = device)
+        speech_encoder_object = ContentVec256L9_Onnx(device = device,log=log)
    elif speech_encoder == "vec256l12-onnx":
         from vencoder.ContentVec256L12_Onnx import ContentVec256L12_Onnx
-        speech_encoder_object = ContentVec256L12_Onnx(device = device)
+        speech_encoder_object = ContentVec256L12_Onnx(device = device,log=log)
     elif speech_encoder == "vec768l9-onnx":
         from vencoder.ContentVec768L9_Onnx import ContentVec768L9_Onnx
-        speech_encoder_object = ContentVec768L9_Onnx(device = device)
+        speech_encoder_object = ContentVec768L9_Onnx(device = device,log=log)
     elif speech_encoder == "vec768l12-onnx":
         from vencoder.ContentVec768L12_Onnx import ContentVec768L12_Onnx
-        speech_encoder_object = ContentVec768L12_Onnx(device = device)
+        speech_encoder_object = ContentVec768L12_Onnx(device = device,log=log)
     elif speech_encoder == "hubertsoft-onnx":
         from vencoder.HubertSoft_Onnx import HubertSoft_Onnx
-        speech_encoder_object = HubertSoft_Onnx(device = device)
+        speech_encoder_object = HubertSoft_Onnx(device = device,log=log)
     elif speech_encoder == "hubertsoft":
         from vencoder.HubertSoft import HubertSoft
-        speech_encoder_object = HubertSoft(device = device)
+        speech_encoder_object = HubertSoft(device = device,log=log)
     elif speech_encoder == "whisper-ppg":
         from vencoder.WhisperPPG import WhisperPPG
-        speech_encoder_object = WhisperPPG(device = device)
+        speech_encoder_object = WhisperPPG(device = device,log=log)
     elif speech_encoder == "cnhubertlarge":
         from vencoder.CNHubertLarge import CNHubertLarge
-        speech_encoder_object = CNHubertLarge(device = device)
+        speech_encoder_object = CNHubertLarge(device = device,log=log)
     elif speech_encoder == "dphubert":
         from vencoder.DPHubert import DPHubert
-        speech_encoder_object = DPHubert(device = device)
+        speech_encoder_object = DPHubert(device = device,log=log)
     elif speech_encoder == "whisper-ppg-large":
         from vencoder.WhisperPPGLarge import WhisperPPGLarge
-        speech_encoder_object = WhisperPPGLarge(device = device)
+        speech_encoder_object = WhisperPPGLarge(device = device,log=log)
     elif speech_encoder == "wavlmbase+":
         from vencoder.WavLMBasePlus import WavLMBasePlus
-        speech_encoder_object = WavLMBasePlus(device = device)
+        speech_encoder_object = WavLMBasePlus(device = device,log=log)
     else:
         raise Exception("Unknown speech encoder")
     return speech_encoder_object
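With the `log` flag threaded from `get_speech_encoder` into every encoder constructor, callers that build one encoder per worker can suppress the repeated "load model(s) from ..." banner, which is exactly what the multi-process preprocessor above does. Typical call sites look like this sketch (the encoder name and device choice are illustrative):

```python
# Usage sketch for the new quiet mode of utils.get_speech_encoder.
import torch

import utils

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Default behaviour: logs the checkpoint path once on load.
hmodel = utils.get_speech_encoder("vec768l12", device=device)

# Worker processes: skip the banner so N workers don't print it N times.
hmodel_quiet = utils.get_speech_encoder("vec768l12", device=device, log=False)
```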
diff --git a/vencoder/CNHubertLarge.py b/vencoder/CNHubertLarge.py
index fdaa0c20..1ce0c4d5 100644
--- a/vencoder/CNHubertLarge.py
+++ b/vencoder/CNHubertLarge.py
@@ -5,10 +5,11 @@
 
 
 class CNHubertLarge(SpeechEncoder):
-    def __init__(self, vec_path="pretrain/chinese-hubert-large-fairseq-ckpt.pt", device=None):
+    def __init__(self, vec_path="pretrain/chinese-hubert-large-fairseq-ckpt.pt", device=None, log=True):
         super().__init__()
         import logger
-        logger.info("load model(s) from {}".format(vec_path))
+        if log:
+            logger.info("load model(s) from {}".format(vec_path))
         self.hidden_dim = 1024
         models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task(
             [vec_path],
diff --git a/vencoder/ContentVec256L12_Onnx.py b/vencoder/ContentVec256L12_Onnx.py
index 601da1ea..3ad1163b 100644
--- a/vencoder/ContentVec256L12_Onnx.py
+++ b/vencoder/ContentVec256L12_Onnx.py
@@ -5,10 +5,11 @@
 
 
 class ContentVec256L12_Onnx(SpeechEncoder):
-    def __init__(self, vec_path="pretrain/vec-256-layer-12.onnx", device=None):
+    def __init__(self, vec_path="pretrain/vec-256-layer-12.onnx", device=None, log=True):
         super().__init__()
         import logger
-        logger.info("load model(s) from {}".format(vec_path))
+        if log:
+            logger.info("load model(s) from {}".format(vec_path))
         self.hidden_dim = 256
         if device is None:
             self.dev = torch.device("cpu")
diff --git a/vencoder/ContentVec256L9.py b/vencoder/ContentVec256L9.py
index dfcab8b1..1090fe2b 100644
--- a/vencoder/ContentVec256L9.py
+++ b/vencoder/ContentVec256L9.py
@@ -5,10 +5,11 @@
 
 
 class ContentVec256L9(SpeechEncoder):
-    def __init__(self, vec_path="pretrain/checkpoint_best_legacy_500.pt", device=None):
+    def __init__(self, vec_path="pretrain/checkpoint_best_legacy_500.pt", device=None, log=True):
         super().__init__()
         import logger
-        logger.info("load model(s) from {}".format(vec_path))
+        if log:
+            logger.info("load model(s) from {}".format(vec_path))
         models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task(
             [vec_path],
             suffix="",
diff --git a/vencoder/ContentVec256L9_Onnx.py b/vencoder/ContentVec256L9_Onnx.py
index 78173da4..06980667 100644
--- a/vencoder/ContentVec256L9_Onnx.py
+++ b/vencoder/ContentVec256L9_Onnx.py
@@ -5,10 +5,11 @@
 
 
 class ContentVec256L9_Onnx(SpeechEncoder):
-    def __init__(self, vec_path="pretrain/vec-256-layer-9.onnx", device=None):
+    def __init__(self, vec_path="pretrain/vec-256-layer-9.onnx", device=None, log=True):
         super().__init__()
         import logger
-        logger.info("load model(s) from {}".format(vec_path))
+        if log:
+            logger.info("load model(s) from {}".format(vec_path))
         self.hidden_dim = 256
         if device is None:
             self.dev = torch.device("cpu")
diff --git a/vencoder/ContentVec768L12.py b/vencoder/ContentVec768L12.py
index ad06a979..4772f43b 100644
--- a/vencoder/ContentVec768L12.py
+++ b/vencoder/ContentVec768L12.py
@@ -5,10 +5,11 @@
 
 
 class ContentVec768L12(SpeechEncoder):
-    def __init__(self, vec_path="pretrain/checkpoint_best_legacy_500.pt", device=None):
+    def __init__(self, vec_path="pretrain/checkpoint_best_legacy_500.pt", device=None, log=True):
         super().__init__()
         import logger
-        logger.info("load model(s) from {}".format(vec_path))
+        if log:
+            logger.info("load model(s) from {}".format(vec_path))
         self.hidden_dim = 768
         models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task(
             [vec_path],
diff --git a/vencoder/ContentVec768L12_Onnx.py b/vencoder/ContentVec768L12_Onnx.py
index cf4e2705..8b7862a0 100644
--- a/vencoder/ContentVec768L12_Onnx.py
+++ b/vencoder/ContentVec768L12_Onnx.py
@@ -5,10 +5,11 @@
 
 
 class ContentVec768L12_Onnx(SpeechEncoder):
-    def __init__(self, vec_path="pretrain/vec-768-layer-12.onnx", device=None):
+    def __init__(self, vec_path="pretrain/vec-768-layer-12.onnx", device=None, log=True):
         super().__init__()
         import logger
-        logger.info("load model(s) from {}".format(vec_path))
+        if log:
+            logger.info("load model(s) from {}".format(vec_path))
         self.hidden_dim = 768
         if device is None:
             self.dev = torch.device("cpu")
diff --git a/vencoder/DPHubert.py b/vencoder/DPHubert.py
index 540eb58c..7077753c 100644
--- a/vencoder/DPHubert.py
+++ b/vencoder/DPHubert.py
@@ -5,10 +5,11 @@
 
 
 class DPHubert(SpeechEncoder):
-    def __init__(self, vec_path="pretrain/DPHuBERT-sp0.75.pth", device=None):
+    def __init__(self, vec_path="pretrain/DPHuBERT-sp0.75.pth", device=None, log=True):
         super().__init__()
         import logger
-        logger.info("load model(s) from {}".format(vec_path))
+        if log:
+            logger.info("load model(s) from {}".format(vec_path))
         if device is None:
             self.dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         else:
diff --git a/vencoder/HubertSoft.py b/vencoder/HubertSoft.py
index b3b01442..3af0a935 100644
--- a/vencoder/HubertSoft.py
+++ b/vencoder/HubertSoft.py
@@ -5,10 +5,11 @@
 
 
 class HubertSoft(SpeechEncoder):
-    def __init__(self, vec_path="pretrain/hubert-soft-0d54a1f4.pt", device=None):
+    def __init__(self, vec_path="pretrain/hubert-soft-0d54a1f4.pt", device=None, log=True):
         super().__init__()
         import logger
-        logger.info("load model(s) from {}".format(vec_path))
+        if log:
+            logger.info("load model(s) from {}".format(vec_path))
         hubert_soft = hubert_model.hubert_soft(vec_path)
         if device is None:
             self.dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
diff --git a/vencoder/HubertSoft_Onnx.py b/vencoder/HubertSoft_Onnx.py
index 52fadf95..bb8d507d 100644
--- a/vencoder/HubertSoft_Onnx.py
+++ b/vencoder/HubertSoft_Onnx.py
@@ -5,10 +5,11 @@
 
 
 class HubertSoft_Onnx(SpeechEncoder):
-    def __init__(self, vec_path="pretrain/hubert-soft.onnx", device=None):
+    def __init__(self, vec_path="pretrain/hubert-soft.onnx", device=None, log=True):
         super().__init__()
         import logger
-        logger.info("load model(s) from {}".format(vec_path))
+        if log:
+            logger.info("load model(s) from {}".format(vec_path))
         self.hidden_dim = 256
         if device is None:
             self.dev = torch.device("cpu")
diff --git a/vencoder/WavLMBasePlus.py b/vencoder/WavLMBasePlus.py
index 8dc82592..2a8d3c74 100644
--- a/vencoder/WavLMBasePlus.py
+++ b/vencoder/WavLMBasePlus.py
@@ -5,10 +5,11 @@
 
 
 class WavLMBasePlus(SpeechEncoder):
-    def __init__(self, vec_path="pretrain/WavLM-Base+.pt", device=None):
+    def __init__(self, vec_path="pretrain/WavLM-Base+.pt", device=None, log=True):
         super().__init__()
         import logger
-        logger.info("load model(s) from {}".format(vec_path))
+        if log:
+            logger.info("load model(s) from {}".format(vec_path))
         checkpoint = torch.load(vec_path)
         self.cfg = WavLMConfig(checkpoint['cfg'])
         if device is None: