From 4c92605268eebf3a6b307cc5d5de6d7b0024f667 Mon Sep 17 00:00:00 2001 From: sml2h3 Date: Mon, 28 Feb 2022 11:22:40 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E4=BA=86=E4=B8=80=E5=A0=86bu?= =?UTF-8?q?g?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- nets/__init__.py | 26 ++++++++++++++++++-------- utils/load_cache.py | 16 ++++++---------- utils/train.py | 39 +++++++++++++++++++++++---------------- 3 files changed, 47 insertions(+), 34 deletions(-) diff --git a/nets/__init__.py b/nets/__init__.py index 6cb8199f..ae55401a 100644 --- a/nets/__init__.py +++ b/nets/__init__.py @@ -12,7 +12,7 @@ class Net(torch.nn.Module): - def __init__(self, conf): + def __init__(self, conf, lr=None): super(Net, self).__init__() self.backbones_list = { @@ -75,7 +75,10 @@ def __init__(self, conf): self.paramters.append({'params': self.fc.parameters()}) - self.lr = self.conf['Train']['LR'] + if lr == None: + self.lr = self.conf['Train']['LR'] + else: + self.lr = lr self.optim = self.conf['Train']['OPTIMIZER'] if self.optim in self.optimizers_list: @@ -89,6 +92,7 @@ def __init__(self, conf): self.scheduler = torch.optim.lr_scheduler.ExponentialLR(self.optimizer, gamma=0.98) + def forward(self, inputs): predict = self.get_features(inputs) if self.word: @@ -183,7 +187,8 @@ def get_loss(self, predict, labels, labels_length): def save_model(self, path, net): torch.save(net, path) - def get_device(self, gpu_id): + @staticmethod + def get_device(gpu_id): if gpu_id == -1: device = torch.device('cpu'.format(str(gpu_id))) else: @@ -212,10 +217,15 @@ def export_onnx(self, net, dummy_input, graph_path, input_names, output_names, d input_names=input_names, output_names=output_names, dynamic_axes=dynamic_ax, opset_version=12, do_constant_folding=True, _retain_param_name=False) - def load_checkpoint(self, path): - param = torch.load(path) + + @staticmethod + def load_checkpoint(path, device): + param = torch.load(path, map_location=device) state_dict = param['net'] optimizer = param['optimizer'] - self.load_state_dict(state_dict) - self.optimizer.load_state_dict(optimizer) - return param['epoch'], param['step'], param['lr'] + # self.lr = param['lr'] + # self.reset_optimizer(param['epoch']) + # self.load_state_dict(state_dict) + # self.optimizer.load_state_dict(optimizer) + # return param['epoch'], param['step'], param['lr'] + return param, state_dict, optimizer diff --git a/utils/load_cache.py b/utils/load_cache.py index ed1f7ef1..7499ebbf 100644 --- a/utils/load_cache.py +++ b/utils/load_cache.py @@ -2,7 +2,6 @@ import os import torch -import tqdm from configs import Config from loguru import logger @@ -22,17 +21,11 @@ def __init__(self, cache_path: str, path: str, word: bool, image_channel: int, r self.ImageChannel = image_channel self.resize = resize self.charset = charset - + self.caches = [] logger.info("\nReading Cache File... ----> {}".format(self.cache_path)) with open(self.cache_path, 'r', encoding='utf-8') as f: - caches = f.readlines() - self.caches = [] - for cache in tqdm.tqdm(caches): - cache = cache.replace("\r", "").replace("\n", "").split("\t") - self.caches.append(cache) - del caches - + self.caches = f.readlines() self.caches_num = len(self.caches) logger.info("\nRead Cache File End! Caches Num is {}.".format(self.caches_num)) @@ -42,6 +35,7 @@ def __len__(self): def __getitem__(self, idx): try: data = self.caches[idx] + data = data.replace("\r", "").replace("\n", "").split("\t") image_name = data[0] image_label = data[1] image_path = os.path.join(self.path, image_name) @@ -70,7 +64,7 @@ def __getitem__(self, idx): return image, label except Exception as e: - logger.error("\nError: {}, File: {}".format(str(e), self.caches[idx][0])) + logger.error("\nError: {}, File: {}".format(str(e), self.caches[idx].split("\t")[0])) return None, None @@ -148,6 +142,8 @@ def __init__(self, project_name: str): 'val': DataLoader(dataset=val_loader, batch_size=self.val_batch_size, shuffle=True, drop_last=True, num_workers=0, collate_fn=self.collate_to_sparse), } + del val_loader + del train_loader def collate_to_sparse(self, batch): values = [] diff --git a/utils/train.py b/utils/train.py index 04fa0aa7..67f6a2f8 100644 --- a/utils/train.py +++ b/utils/train.py @@ -20,6 +20,9 @@ def __init__(self, project_name: str): self.models_path = os.path.join(self.project_path, "models") self.epoch = 0 self.step = 0 + self.lr = None + self.state_dict = None + self.optimizer = None self.config = Config(project_name) self.conf = self.config.load_config() @@ -36,23 +39,16 @@ def __init__(self, project_name: str): self.ImageChannel = self.conf['Model']['ImageChannel'] logger.info("\nTaget:\nmin_Accuracy: {}\nmin_Epoch: {}\nmax_Loss: {}".format(self.target_acc, self.min_epoch, self.max_loss)) - - logger.info("\nBuilding Net...") - self.net = Net(self.conf) - logger.info(self.net) - logger.info("\nBuilding End") - self.use_gpu = self.conf['System']['GPU'] if self.use_gpu: self.gpu_id = self.conf['System']['GPU_ID'] logger.info("\nUSE GPU ----> {}".format(self.gpu_id)) - self.device = self.net.get_device(self.gpu_id) - self.net.to(self.device) + self.device = Net.get_device(self.gpu_id) + else: self.gpu_id = -1 - self.device = self.net.get_device(self.gpu_id) + self.device = Net.get_device(self.gpu_id) logger.info("\nUSE CPU".format(self.gpu_id)) - logger.info("\nSearch for history checkpoints...") history_checkpoints = os.listdir(self.checkpoints_path) if len(history_checkpoints) > 0: @@ -60,25 +56,36 @@ def __init__(self, project_name: str): newer_checkpoint = None for checkpoint in history_checkpoints: checkpoint_name = checkpoint.split(".")[0].split("_") - if int(checkpoint_name[2]) > history_step: + if int(checkpoint_name[3]) > history_step: newer_checkpoint = checkpoint - history_step = int(checkpoint_name[2]) - self.epoch, self.step, self.lr = self.net.load_checkpoint( - os.path.join(self.checkpoints_path, newer_checkpoint)) + history_step = int(checkpoint_name[3]) + param, self.state_dict, self.optimizer= Net.load_checkpoint( + os.path.join(self.checkpoints_path, newer_checkpoint), self.device) + self.epoch, self.step, self.lr = param['epoch'], param['step'], param['lr'] self.epoch += 1 self.step += 1 - self.net.lr = self.lr else: logger.info("\nEmpty history checkpoints") + + logger.info("\nBuilding Net...") + self.net = Net(self.conf, self.lr) + if self.state_dict: + self.net.load_state_dict(self.state_dict) + logger.info(self.net) + logger.info("\nBuilding End") + + + + self.net = self.net.to(self.device) logger.info("\nGet Data Loader...") + loaders = load_cache.GetLoader(project_name) self.train = loaders.loaders['train'] self.val = loaders.loaders['val'] del loaders logger.info("\nGet Data Loader End!") - self.loss = 0 self.avg_loss = 0 self.start_time = time.time()