diff --git a/data/dataloader.py b/data/dataloader.py
index 323c438..6048308 100644
--- a/data/dataloader.py
+++ b/data/dataloader.py
@@ -15,19 +15,7 @@
 from PIL import Image
 import torchvision.transforms.functional as TF
 import random
-
-
-def check_data(data_folder):
-    masks = set(os.listdir(f'{data_folder}/masks/'))
-    image = set(os.listdir(f'{data_folder}/images/'))
-
-    intersection = masks.intersection(image)
-    union = masks.union(image)
-    print(f"[!] {len(union) - len(intersection)} of {len(union)} images doesn't have mask")
-
-    intersection = list(intersection)
-
-    return intersection
+from data.utils import check_data
 
 
 def transform(image, mask, image_size=224):
diff --git a/data/test_loader.py b/data/test_loader.py
index 000f6fb..c814cc1 100644
--- a/data/test_loader.py
+++ b/data/test_loader.py
@@ -4,7 +4,7 @@
 import torchvision.transforms as transforms
 from PIL import Image
 import torchvision.transforms.functional as TF
-from random import random
+from data.utils import check_data
 
 
 def transform(image, mask, image_size=224):
@@ -12,28 +12,14 @@
     image = resize(image)
     mask = resize(mask)
 
-    # if random() > 0.5:
-    #     image = TF.vflip(image)
-    #     mask = TF.vflip(mask)
-
-    # if random() > 0.5:
-    #     image = TF.hflip(image)
-    #     mask = TF.hflip(mask)
-
-    # angle = random() * 12 - 6
-    # image = TF.rotate(image, angle)
-    # mask = TF.rotate(mask, angle)
-
-    # pad_size = random() * image_size
-    # image = TF.pad(image, pad_size, padding_mode='edge')
-    # mask = TF.pad(mask, pad_size, padding_mode='edge')
+    mask = TF.to_grayscale(mask)
 
     # Transform to tensor
     image = TF.to_tensor(image)
     mask = TF.to_tensor(mask)
 
     # Normalize Data
-    image = TF.normalize(image, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
+    image = TF.normalize(image, [0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
 
     return image, mask
@@ -45,7 +31,7 @@ def __init__(self, data_folder, image_size):
             raise Exception(f"[!] {self.data_folder} not exists.")
 
         self.objects_path = []
-        self.image_name = os.listdir(os.path.join(data_folder, "images"))
+        self.image_name = check_data(data_folder)
         if len(self.image_name) == 0:
             raise Exception(f"No image found in {self.image_name}")
         for p in os.listdir(data_folder):
diff --git a/data/utils.py b/data/utils.py
new file mode 100644
index 0000000..a8af748
--- /dev/null
+++ b/data/utils.py
@@ -0,0 +1,14 @@
+import os
+
+
+def check_data(data_folder):
+    masks = set(os.listdir(f'{data_folder}/masks/'))
+    image = set(os.listdir(f'{data_folder}/images/'))
+
+    intersection = masks.intersection(image)
+    union = masks.union(image)
+    print(f"[!] {len(union) - len(intersection)} of {len(union)} images doesn't have mask")
+
+    intersection = list(intersection)
+
+    return intersection
\ No newline at end of file
diff --git a/main.py b/main.py
index bd96efb..173e396 100644
--- a/main.py
+++ b/main.py
@@ -27,7 +27,8 @@ def main(config):
     if torch.cuda.is_available():
         torch.cuda.manual_seed_all(config.manual_seed)
 
-    torch.backends.cudnn.benchmark = True
+    torch.backends.cudnn.deterministic = True
+    torch.backends.cudnn.benchmark = False
 
     if not config.test:
         trainer = Trainer(config)
diff --git a/param/best.pt b/param/best.pt
index f88e85d..3b00444 100644
Binary files a/param/best.pt and b/param/best.pt differ
diff --git a/param/last.pt b/param/last.pt
index 68984f4..225bb18 100644
Binary files a/param/last.pt and b/param/last.pt differ
diff --git a/param/quantized.pt b/param/quantized.pt
index 3d8ffcd..aa8bcba 100644
Binary files a/param/quantized.pt and b/param/quantized.pt differ
diff --git a/src/test.py b/src/test.py
index 761b5ba..f1a0d29 100644
--- a/src/test.py
+++ b/src/test.py
@@ -37,27 +37,28 @@ def load_model(self):
     def test(self, net=None):
         if net:
             self.net = net
-        avg_meter = AverageMeter()
-        unnormal = UnNormalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
-        pbar = tqdm(enumerate(self.data_loader), total=len(self.data_loader))
-        for step, (image, mask) in pbar:
-            image = image.to(self.device)
-            #image = unnormal(image.to(self.device))
-            result = self.net(image)
+        avg_meter = AverageMeter()
+        with torch.no_grad():
+            unnormal = UnNormalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
+            pbar = tqdm(enumerate(self.data_loader), total=len(self.data_loader))
+            for step, (image, mask) in pbar:
+                image = image.to(self.device)
+                #image = unnormal(image.to(self.device))
+                result = self.net(image)
 
-            mask = mask.to(self.device)
+                mask = mask.to(self.device)
 
-            avg_meter.update(iou_loss(result, mask))
-            pbar.set_description(f'IOU: {avg_meter.avg:.4f}')
+                avg_meter.update(iou_loss(result, mask))
+                pbar.set_description(f'IOU: {avg_meter.avg:.4f}')
 
-            mask = mask.repeat_interleave(3, 1)
-            argmax = torch.argmax(result, dim=1).unsqueeze(dim=1)
-            result = result[:, 1, :, :].unsqueeze(dim=1)
-            result = result * argmax
-            result = result.repeat_interleave(3, 1)
-            torch.cat([image, result, mask])
+                mask = mask.repeat_interleave(3, 1)
+                argmax = torch.argmax(result, dim=1).unsqueeze(dim=1)
+                result = result[:, 1, :, :].unsqueeze(dim=1)
+                result = result * argmax
+                result = result.repeat_interleave(3, 1)
+                torch.cat([image, result, mask])
 
-            save_image(torch.cat([image, result, mask]), os.path.join(self.sample_dir, f"{step}.png"))
+                save_image(torch.cat([image, result, mask]), os.path.join(self.sample_dir, f"{step}.png"))
diff --git a/src/train.py b/src/train.py
index ee35b4b..75d8842 100644
--- a/src/train.py
+++ b/src/train.py
@@ -131,7 +131,7 @@ def quantize_model(self):
             return
 
         print('Load Best Model')
-        ckpt = f'{self.checkpoint_dir}/best.pt'
+        ckpt = f'{self.checkpoint_dir}/mobilenetv2.pt'
         save_info = torch.load(ckpt, map_location=self.device)
         self.net = save_info['model']
         self.net.load_state_dict(save_info['state_dict'])
@@ -154,6 +154,9 @@
         temp = self.num_epoch
         self.num_epoch = self.num_quantize_train
 
+        self.lr_scheduler = OneCycleLR(optimizer=self.optimizer, max_lr=self.lr, epochs=self.num_epoch,
+                                       steps_per_epoch=self.image_len, cycle_momentum=False)
+
         for i in range(self.num_quantize_train):
             self._train_one_epoch(i, image_gradient_criterion, bce_criterion, quantize=True)
         self.num_epoch = temp
@@ -216,7 +219,6 @@ def _train_one_epoch(self, epoch, image_gradient_criterion, bce_criterion, quant
                          'lr_scheduler': self.lr_scheduler.state_dict(),
                          'run_id': self.run.id}
             torch.save(save_info, f'{self.checkpoint_dir}/last.pt')
             wandb.save(f'{self.checkpoint_dir}/last.pt', './')
-            print(f"[*] Save model Epoch: {epoch}")
 
         if image is not None:
             img = torch.cat(
                 [image[0], mask[0].repeat(3, 1, 1), pred[0].argmax(dim=0).unsqueeze(dim=0).repeat(3, 1, 1)], dim=2)
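
Usage note: the following is a minimal sketch (not part of the patch) of how the shared data.utils.check_data helper and the determinism settings introduced above fit together; the 'dataset/train' path and the literal seed are illustrative assumptions, not values taken from the repo.

    import random

    import torch

    from data.utils import check_data

    # Mirror main.py's reproducibility setup: a fixed seed plus
    # deterministic cuDNN (benchmark off, deterministic on).
    SEED = 0  # assumption: main.py actually uses config.manual_seed
    random.seed(SEED)
    torch.manual_seed(SEED)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(SEED)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # check_data keeps only the file names present in both
    # <data_folder>/images/ and <data_folder>/masks/, printing how many
    # images lack a mask, so both loaders index only valid pairs.
    paired_names = check_data('dataset/train')  # path is an assumption
    print(f'{len(paired_names)} usable image/mask pairs')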