From c31785947583863f10fd0d0e06cb33b189fa74ed Mon Sep 17 00:00:00 2001 From: Vidalnt Date: Sat, 27 Jul 2024 18:07:16 -0500 Subject: [PATCH] Optimising hparams --- rvc/train/utils.py | 136 ++++++++++++--------------------------------- 1 file changed, 34 insertions(+), 102 deletions(-) diff --git a/rvc/train/utils.py b/rvc/train/utils.py index 17defa0..d01f23d 100644 --- a/rvc/train/utils.py +++ b/rvc/train/utils.py @@ -2,7 +2,7 @@ import glob import json import torch -import argparse +import sys import numpy as np from scipy.io.wavfile import read from collections import OrderedDict @@ -245,113 +245,45 @@ def get_hparams(): """ Parses command line arguments and loads hyperparameters from a configuration file. """ - parser = argparse.ArgumentParser() - parser.add_argument( - "--save_every_epoch", - type=str, - help="Frequency (in epochs) at which checkpoints are saved.", - ) - parser.add_argument( - "--total_epoch", - type=str, - help="Total number of training epochs.", - ) - parser.add_argument( - "--pretrainG", - type=str, - help="Path to the pretrained Generator model.", - ) - parser.add_argument( - "--pretrainD", - type=str, - help="Path to the pretrained Discriminator model.", - ) - parser.add_argument( - "--gpus", - type=str, - help="Hyphen-separated list of GPU device IDs to use (e.g., '0-1-2').", - ) - parser.add_argument( - "--batch_size", - type=str, - help="Batch size for training.", - ) - parser.add_argument( - "--experiment_dir", - type=str, - help="Directory to store experiment outputs.", - ) - parser.add_argument( - "--sample_rate", - type=str, - help="Sample rate to use.", - ) - parser.add_argument( - "--save_every_weights", - type=str, - help="Save the model weights in the weights directory when saving checkpoints.", - ) - parser.add_argument( - "--version", - type=str, - help="Model version identifier.", - ) - parser.add_argument( - "--pitch_guidance", - type=str, - help="Use pitch (f0) as one of the inputs to the model (True or False).", - ) - parser.add_argument( - "--if_latest", - type=str, - help="Only save the latest Generator/Discriminator model files (True or False).", - ) - parser.add_argument( - "--if_cache_data_in_gpu", - type=str, - help="Cache the dataset in GPU memory (True or False).", - ) - parser.add_argument( - "--overtraining_detector", - type=str, - help="Detect overtraining (True or False).", - ) - parser.add_argument( - "--overtraining_threshold", - type=str, - help="Threshold for overtraining detection.", - ) - parser.add_argument( - "--sync_graph", - type=str, - help="Synchronize graph (True or False).", - ) - - args = parser.parse_args() - name = args.experiment_dir - experiment_dir = os.path.join("./logs", args.experiment_dir) + args = {} + for i in range(1, len(sys.argv), 2): + if i + 1 < len(sys.argv): + key = sys.argv[i].replace('--', '') + value = sys.argv[i + 1] + if value.isdigit(): + args[key] = int(value) + elif value.lower() == 'true': + args[key] = True + elif value.lower() == 'false': + args[key] = False + else: + args[key] = value + + experiment_dir = os.path.join("./logs", args["experiment_dir"]) config_save_path = os.path.join(experiment_dir, "config.json") with open(config_save_path, "r") as f: config = json.load(f) + hparams = HParams(**config) hparams.model_dir = hparams.experiment_dir = experiment_dir - hparams.save_every_epoch = int(args.save_every_epoch) - hparams.name = name - hparams.total_epoch = int(args.total_epoch) - hparams.pretrainG = args.pretrainG - hparams.pretrainD = args.pretrainD - hparams.version = args.version - hparams.gpus = args.gpus - hparams.batch_size = int(args.batch_size) - hparams.sample_rate = int(args.sample_rate) - hparams.pitch_guidance = args.pitch_guidance - hparams.if_latest = bool(args.if_latest) - hparams.save_every_weights = bool(args.save_every_weights) - hparams.if_cache_data_in_gpu = bool(args.if_cache_data_in_gpu) + hparams.save_every_epoch = args.get("save_every_epoch") + hparams.name = args.get("experiment_dir") + hparams.total_epoch = args.get("total_epoch") + hparams.pretrainG = args.get("pretrainG") + hparams.pretrainD = args.get("pretrainD") + hparams.version = args.get("version") + hparams.gpus = args.get("gpus") + hparams.batch_size = args.get("batch_size") + hparams.sample_rate = args.get("sample_rate") + hparams.pitch_guidance = args.get("pitch_guidance") + hparams.if_latest = args.get("if_latest") + hparams.save_every_weights = args.get("save_every_weights") + hparams.if_cache_data_in_gpu = args.get("if_cache_data_in_gpu") hparams.data.training_files = f"{experiment_dir}/filelist.txt" - hparams.overtraining_detector = bool(args.overtraining_detector) - hparams.overtraining_threshold = int(args.overtraining_threshold) - hparams.sync_graph = args.sync_graph + hparams.overtraining_detector = args.get("overtraining_detector") + hparams.overtraining_threshold = args.get("overtraining_threshold") + hparams.sync_graph = args.get("sync_graph") + print(hparams) return hparams