
Commit

Fixed a bunch of bugs
sml2h3 committed Feb 28, 2022
1 parent 948c304 commit 4c92605
Showing 3 changed files with 47 additions and 34 deletions.
26 changes: 18 additions & 8 deletions nets/__init__.py
@@ -12,7 +12,7 @@


class Net(torch.nn.Module):
def __init__(self, conf):
def __init__(self, conf, lr=None):
super(Net, self).__init__()

self.backbones_list = {
@@ -75,7 +75,10 @@ def __init__(self, conf):

self.paramters.append({'params': self.fc.parameters()})

self.lr = self.conf['Train']['LR']
if lr == None:
self.lr = self.conf['Train']['LR']
else:
self.lr = lr

self.optim = self.conf['Train']['OPTIMIZER']
if self.optim in self.optimizers_list:
@@ -89,6 +92,7 @@ def __init__(self, conf):

self.scheduler = torch.optim.lr_scheduler.ExponentialLR(self.optimizer, gamma=0.98)


def forward(self, inputs):
predict = self.get_features(inputs)
if self.word:
@@ -183,7 +187,8 @@ def get_loss(self, predict, labels, labels_length):
def save_model(self, path, net):
torch.save(net, path)

def get_device(self, gpu_id):
@staticmethod
def get_device(gpu_id):
if gpu_id == -1:
device = torch.device('cpu'.format(str(gpu_id)))
else:
@@ -212,10 +217,15 @@ def export_onnx(self, net, dummy_input, graph_path, input_names, output_names, d
input_names=input_names, output_names=output_names, dynamic_axes=dynamic_ax,
opset_version=12, do_constant_folding=True, _retain_param_name=False)

def load_checkpoint(self, path):
param = torch.load(path)

@staticmethod
def load_checkpoint(path, device):
param = torch.load(path, map_location=device)
state_dict = param['net']
optimizer = param['optimizer']
self.load_state_dict(state_dict)
self.optimizer.load_state_dict(optimizer)
return param['epoch'], param['step'], param['lr']
# self.lr = param['lr']
# self.reset_optimizer(param['epoch'])
# self.load_state_dict(state_dict)
# self.optimizer.load_state_dict(optimizer)
# return param['epoch'], param['step'], param['lr']
return param, state_dict, optimizer
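
The reworked load_checkpoint is now a static method: it maps tensors onto a caller-supplied device and returns the raw checkpoint dict plus the model and optimizer state dicts instead of restoring them in place, and the restored learning rate is fed back through the new lr argument of Net.__init__. A minimal resume sketch under those assumptions (the project name and checkpoint path below are illustrative, not part of the repository):

    from configs import Config
    from nets import Net

    conf = Config("my_project").load_config()          # illustrative project name
    device = Net.get_device(-1)                        # -1 selects CPU in get_device above

    # load_checkpoint now returns (checkpoint dict, model state_dict, optimizer state_dict)
    param, state_dict, optimizer_state = Net.load_checkpoint(
        "checkpoints/example_checkpoint.tar", device)  # illustrative path

    net = Net(conf, lr=param['lr'])                    # lr restored through the constructor
    net.load_state_dict(state_dict)
    net.optimizer.load_state_dict(optimizer_state)
    net = net.to(device)
    epoch, step = param['epoch'] + 1, param['step'] + 1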
16 changes: 6 additions & 10 deletions utils/load_cache.py
@@ -2,7 +2,6 @@
import os

import torch
import tqdm

from configs import Config
from loguru import logger
@@ -22,17 +21,11 @@ def __init__(self, cache_path: str, path: str, word: bool, image_channel: int, r
self.ImageChannel = image_channel
self.resize = resize
self.charset = charset

self.caches = []
logger.info("\nReading Cache File... ----> {}".format(self.cache_path))

with open(self.cache_path, 'r', encoding='utf-8') as f:
caches = f.readlines()
self.caches = []
for cache in tqdm.tqdm(caches):
cache = cache.replace("\r", "").replace("\n", "").split("\t")
self.caches.append(cache)
del caches

self.caches = f.readlines()
self.caches_num = len(self.caches)
logger.info("\nRead Cache File End! Caches Num is {}.".format(self.caches_num))

@@ -42,6 +35,7 @@ def __len__(self):
def __getitem__(self, idx):
try:
data = self.caches[idx]
data = data.replace("\r", "").replace("\n", "").split("\t")
image_name = data[0]
image_label = data[1]
image_path = os.path.join(self.path, image_name)
@@ -70,7 +64,7 @@ def __getitem__(self, idx):
return image, label

except Exception as e:
logger.error("\nError: {}, File: {}".format(str(e), self.caches[idx][0]))
logger.error("\nError: {}, File: {}".format(str(e), self.caches[idx].split("\t")[0]))
return None, None


@@ -148,6 +142,8 @@ def __init__(self, project_name: str):
'val': DataLoader(dataset=val_loader, batch_size=self.val_batch_size, shuffle=True, drop_last=True,
num_workers=0, collate_fn=self.collate_to_sparse),
}
del val_loader
del train_loader

def collate_to_sparse(self, batch):
values = []
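
The dataset change above defers per-line parsing: __init__ now keeps only the raw lines read from the cache file, and the strip/split on "\t" happens lazily inside __getitem__, which also removes the tqdm dependency and the up-front pass over the whole cache. A standalone sketch of that pattern (the class and file names here are illustrative, not from this repository):

    from torch.utils.data import Dataset

    class LazyCacheDataset(Dataset):
        """Parses one cache line per __getitem__ call instead of all lines up front."""

        def __init__(self, cache_path: str):
            with open(cache_path, 'r', encoding='utf-8') as f:
                self.caches = f.readlines()   # raw lines only; no per-line work here

        def __len__(self):
            return len(self.caches)

        def __getitem__(self, idx):
            # Lazy parsing, mirroring the __getitem__ change above.
            fields = self.caches[idx].replace("\r", "").replace("\n", "").split("\t")
            image_name, image_label = fields[0], fields[1]
            return image_name, image_label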
39 changes: 23 additions & 16 deletions utils/train.py
@@ -20,6 +20,9 @@ def __init__(self, project_name: str):
self.models_path = os.path.join(self.project_path, "models")
self.epoch = 0
self.step = 0
self.lr = None
self.state_dict = None
self.optimizer = None
self.config = Config(project_name)
self.conf = self.config.load_config()

@@ -36,49 +39,53 @@ def __init__(self, project_name: str):
self.ImageChannel = self.conf['Model']['ImageChannel']
logger.info("\nTaget:\nmin_Accuracy: {}\nmin_Epoch: {}\nmax_Loss: {}".format(self.target_acc, self.min_epoch,
self.max_loss))

logger.info("\nBuilding Net...")
self.net = Net(self.conf)
logger.info(self.net)
logger.info("\nBuilding End")

self.use_gpu = self.conf['System']['GPU']
if self.use_gpu:
self.gpu_id = self.conf['System']['GPU_ID']
logger.info("\nUSE GPU ----> {}".format(self.gpu_id))
self.device = self.net.get_device(self.gpu_id)
self.net.to(self.device)
self.device = Net.get_device(self.gpu_id)

else:
self.gpu_id = -1
self.device = self.net.get_device(self.gpu_id)
self.device = Net.get_device(self.gpu_id)
logger.info("\nUSE CPU".format(self.gpu_id))

logger.info("\nSearch for history checkpoints...")
history_checkpoints = os.listdir(self.checkpoints_path)
if len(history_checkpoints) > 0:
history_step = 0
newer_checkpoint = None
for checkpoint in history_checkpoints:
checkpoint_name = checkpoint.split(".")[0].split("_")
if int(checkpoint_name[2]) > history_step:
if int(checkpoint_name[3]) > history_step:
newer_checkpoint = checkpoint
history_step = int(checkpoint_name[2])
self.epoch, self.step, self.lr = self.net.load_checkpoint(
os.path.join(self.checkpoints_path, newer_checkpoint))
history_step = int(checkpoint_name[3])
param, self.state_dict, self.optimizer = Net.load_checkpoint(
os.path.join(self.checkpoints_path, newer_checkpoint), self.device)
self.epoch, self.step, self.lr = param['epoch'], param['step'], param['lr']
self.epoch += 1
self.step += 1
self.net.lr = self.lr

else:
logger.info("\nEmpty history checkpoints")

logger.info("\nBuilding Net...")
self.net = Net(self.conf, self.lr)
if self.state_dict:
self.net.load_state_dict(self.state_dict)
logger.info(self.net)
logger.info("\nBuilding End")



self.net = self.net.to(self.device)
logger.info("\nGet Data Loader...")

loaders = load_cache.GetLoader(project_name)
self.train = loaders.loaders['train']
self.val = loaders.loaders['val']
del loaders
logger.info("\nGet Data Loader End!")


self.loss = 0
self.avg_loss = 0
self.start_time = time.time()
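
train.py now resolves the device before the network exists (via the static Net.get_device), picks the newest checkpoint by the field at index 3 of the underscore-split filename instead of index 2, and only then builds Net with the restored learning rate and state dict. A small sketch of that selection logic, assuming checkpoint names of the form checkpoint_<project>_<epoch>_<step>.tar (the naming scheme is inferred from this diff, not documented here):

    import os

    def newest_checkpoint(checkpoints_path: str):
        """Return the checkpoint file whose step field (index 3) is largest, or None."""
        newest, history_step = None, 0
        for checkpoint in os.listdir(checkpoints_path):
            # e.g. "checkpoint_demo_3_4500.tar" -> ["checkpoint", "demo", "3", "4500"]
            fields = checkpoint.split(".")[0].split("_")
            if int(fields[3]) > history_step:
                newest, history_step = checkpoint, int(fields[3])
        return newest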
