Optimising hparams
Vidalnt committed Jul 27, 2024
1 parent 1cd515c commit c317859
Showing 1 changed file with 34 additions and 102 deletions.
136 changes: 34 additions & 102 deletions rvc/train/utils.py
@@ -2,7 +2,7 @@
 import glob
 import json
 import torch
-import argparse
+import sys
 import numpy as np
 from scipy.io.wavfile import read
 from collections import OrderedDict
@@ -245,113 +245,45 @@ def get_hparams():
"""
Parses command line arguments and loads hyperparameters from a configuration file.
"""
parser = argparse.ArgumentParser()
parser.add_argument(
"--save_every_epoch",
type=str,
help="Frequency (in epochs) at which checkpoints are saved.",
)
parser.add_argument(
"--total_epoch",
type=str,
help="Total number of training epochs.",
)
parser.add_argument(
"--pretrainG",
type=str,
help="Path to the pretrained Generator model.",
)
parser.add_argument(
"--pretrainD",
type=str,
help="Path to the pretrained Discriminator model.",
)
parser.add_argument(
"--gpus",
type=str,
help="Hyphen-separated list of GPU device IDs to use (e.g., '0-1-2').",
)
parser.add_argument(
"--batch_size",
type=str,
help="Batch size for training.",
)
parser.add_argument(
"--experiment_dir",
type=str,
help="Directory to store experiment outputs.",
)
parser.add_argument(
"--sample_rate",
type=str,
help="Sample rate to use.",
)
parser.add_argument(
"--save_every_weights",
type=str,
help="Save the model weights in the weights directory when saving checkpoints.",
)
parser.add_argument(
"--version",
type=str,
help="Model version identifier.",
)
parser.add_argument(
"--pitch_guidance",
type=str,
help="Use pitch (f0) as one of the inputs to the model (True or False).",
)
parser.add_argument(
"--if_latest",
type=str,
help="Only save the latest Generator/Discriminator model files (True or False).",
)
parser.add_argument(
"--if_cache_data_in_gpu",
type=str,
help="Cache the dataset in GPU memory (True or False).",
)
parser.add_argument(
"--overtraining_detector",
type=str,
help="Detect overtraining (True or False).",
)
parser.add_argument(
"--overtraining_threshold",
type=str,
help="Threshold for overtraining detection.",
)
parser.add_argument(
"--sync_graph",
type=str,
help="Synchronize graph (True or False).",
)

args = parser.parse_args()
name = args.experiment_dir
experiment_dir = os.path.join("./logs", args.experiment_dir)
args = {}
for i in range(1, len(sys.argv), 2):
if i + 1 < len(sys.argv):
key = sys.argv[i].replace('--', '')
value = sys.argv[i + 1]
if value.isdigit():
args[key] = int(value)
elif value.lower() == 'true':
args[key] = True
elif value.lower() == 'false':
args[key] = False
else:
args[key] = value

experiment_dir = os.path.join("./logs", args["experiment_dir"])
config_save_path = os.path.join(experiment_dir, "config.json")
with open(config_save_path, "r") as f:
config = json.load(f)

hparams = HParams(**config)
hparams.model_dir = hparams.experiment_dir = experiment_dir
hparams.save_every_epoch = int(args.save_every_epoch)
hparams.name = name
hparams.total_epoch = int(args.total_epoch)
hparams.pretrainG = args.pretrainG
hparams.pretrainD = args.pretrainD
hparams.version = args.version
hparams.gpus = args.gpus
hparams.batch_size = int(args.batch_size)
hparams.sample_rate = int(args.sample_rate)
hparams.pitch_guidance = args.pitch_guidance
hparams.if_latest = bool(args.if_latest)
hparams.save_every_weights = bool(args.save_every_weights)
hparams.if_cache_data_in_gpu = bool(args.if_cache_data_in_gpu)
hparams.save_every_epoch = args.get("save_every_epoch")
hparams.name = args.get("experiment_dir")
hparams.total_epoch = args.get("total_epoch")
hparams.pretrainG = args.get("pretrainG")
hparams.pretrainD = args.get("pretrainD")
hparams.version = args.get("version")
hparams.gpus = args.get("gpus")
hparams.batch_size = args.get("batch_size")
hparams.sample_rate = args.get("sample_rate")
hparams.pitch_guidance = args.get("pitch_guidance")
hparams.if_latest = args.get("if_latest")
hparams.save_every_weights = args.get("save_every_weights")
hparams.if_cache_data_in_gpu = args.get("if_cache_data_in_gpu")
hparams.data.training_files = f"{experiment_dir}/filelist.txt"
hparams.overtraining_detector = bool(args.overtraining_detector)
hparams.overtraining_threshold = int(args.overtraining_threshold)
hparams.sync_graph = args.sync_graph
hparams.overtraining_detector = args.get("overtraining_detector")
hparams.overtraining_threshold = args.get("overtraining_threshold")
hparams.sync_graph = args.get("sync_graph")

print(hparams)
return hparams
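
For reference, a minimal sketch (not part of the commit) of how the new flag handling behaves: it replays the same parsing loop on a hypothetical command line (the script name train.py and the flag values below are made-up examples) so the type coercion is easy to see — digit strings become int, 'true'/'false' become bool, and everything else stays a string.

# Illustrative only -- not part of the commit. Hypothetical command line:
import sys

sys.argv = [
    "train.py",                     # assumed script name
    "--total_epoch", "300",
    "--pitch_guidance", "True",
    "--experiment_dir", "my-voice",
]

args = {}
for i in range(1, len(sys.argv), 2):         # walk "--flag value" pairs
    if i + 1 < len(sys.argv):
        key = sys.argv[i].replace('--', '')  # strip the leading dashes
        value = sys.argv[i + 1]
        if value.isdigit():                  # "300" -> 300
            args[key] = int(value)
        elif value.lower() == 'true':        # "True" -> True
            args[key] = True
        elif value.lower() == 'false':       # "False" -> False
            args[key] = False
        else:                                # anything else stays a string
            args[key] = value

print(args)
# -> {'total_epoch': 300, 'pitch_guidance': True, 'experiment_dir': 'my-voice'}

Unlike argparse, this approach assumes flags and values strictly alternate on the command line and performs no validation, so any flag that is not passed simply comes back as None from args.get(...).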
