This repository has been archived by the owner on Oct 31, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 13
/
Copy pathutils_img.py
175 lines (149 loc) · 6.25 KB
/
utils_img.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import os
import numpy as np
import torch
import torch.nn.functional as F
from torch.autograd.variable import Variable
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.transforms import functional
from PIL import Image
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Per-channel normalization using the mean/std values listed here.
NORMALIZE_IMAGENET = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)

# The same statistics as tensors reshaped to (C, 1, 1), stored on `device`,
# so they broadcast over the spatial dimensions of an image tensor.
image_mean = torch.Tensor(NORMALIZE_IMAGENET.mean).view(-1, 1, 1).to(device)
image_std = torch.Tensor(NORMALIZE_IMAGENET.std).view(-1, 1, 1).to(device)

# Default preprocessing pipeline: convert to a tensor, then normalize.
default_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        NORMALIZE_IMAGENET,
    ]
)
def normalize_img(x):
    """Normalize an image tensor to approx. [-1,1].

    The input is moved to the module-level ``device``; the module-level
    channel mean is then subtracted and the result divided by the channel std.

    Args:
        x: Image tensor (presumably with values in [0,1] -- confirm with caller)
    Returns:
        The normalized tensor, located on ``device``
    """
    centered = x.to(device) - image_mean
    return centered / image_std
def unnormalize_img(x):
    """Undo ``normalize_img``: map a normalized image tensor back to [0,1].

    The input is moved to the module-level ``device``, multiplied by the
    channel std and shifted by the channel mean.

    Args:
        x: Image tensor with values approx. between [-1,1]
    Returns:
        The unnormalized tensor, located on ``device``
    """
    rescaled = x.to(device) * image_std
    return rescaled + image_mean
def round_pixel(x):
    """
    Snap an image to integer pixel values and return it in normalized space.
    Args:
        x: Image tensor with values approx. between [-1,1]
    Returns:
        y: Rounded image tensor with values approx. between [-1,1]
    """
    # Go to the [0, 255] pixel scale, where rounding is meaningful.
    pixels = 255 * unnormalize_img(x)
    # Round to the nearest integer and keep the result a valid 8-bit value.
    quantized = pixels.round().clamp(0, 255)
    # Return to the normalized representation.
    return normalize_img(quantized / 255.0)
def project_linf(x, y, radius):
    """
    Project x onto the Linf ball around y, so that Linf(x,y)<=radius
    Args:
        x: Image tensor with values approx. between [-1,1]
        y: Image tensor with values approx. between [-1,1], ex: original image
        radius: Radius of Linf ball for the images in pixel space [0, 255]
    Returns:
        y plus the clamped perturbation, in normalized space
    """
    # Express the perturbation on the [0, 255] pixel scale.
    pixel_delta = 255 * ((x - y) * image_std)
    # Limit every component to [-radius, radius].
    clipped = pixel_delta.clamp(-radius, radius)
    # Convert back to normalized space before adding to the reference image.
    return y + (clipped / 255.0) / image_std
def psnr_clip(x, y, target_psnr):
    """
    Rescale the perturbation x-y so that PSNR(x,y)>=target_psnr.

    If the PSNR is already at or above the target, x is returned unchanged
    (up to floating point round-off); otherwise the perturbation is shrunk
    so that the PSNR equals the target. The mean squared error is taken over
    all elements of the tensor, so a batch is treated as a single signal.
    Args:
        x: Image tensor with values approx. between [-1,1]
        y: Image tensor with values approx. between [-1,1], ex: original image
        target_psnr: Target PSNR value in dB
    Returns:
        y plus the (possibly rescaled) perturbation, in normalized space
    """
    # Express the perturbation on the [0, 255] pixel scale.
    delta = 255 * ((x - y) * image_std)
    # PSNR = 20*log10(MAX) - 10*log10(MSE) with MAX = 255.
    # When x == y the MSE is 0 and the PSNR is +inf, so no rescaling happens.
    psnr = 20 * np.log10(255) - 10 * torch.log10(torch.mean(delta ** 2))
    if psnr < target_psnr:
        # Scaling delta by s changes the PSNR by -20*log10(s); this factor
        # (< 1) raises the PSNR exactly to target_psnr.
        delta = torch.sqrt(10 ** ((psnr - target_psnr) / 10)) * delta
    # Convert back to normalized space before adding to the reference image.
    return y + (delta / 255.0) / image_std
class SSIMAttenuation:
    """SSIM-based attenuation of the difference between two 3-channel images."""

    def __init__(self, window_size=17, sigma=1.5, device="cpu"):
        """ Self-similarity attenuation, to make sure that the augmentations occur in high-detail zones.
        Args:
            window_size: Side length of the square Gaussian window
            sigma: Standard deviation of the Gaussian window
            device: Device on which the window is initially stored
        """
        self.window_size = window_size
        center = window_size // 2
        # 1D Gaussian profile, normalized to sum to 1 and shaped as a column.
        gauss_1d = torch.Tensor(
            [np.exp(-(i - center) ** 2 / float(2 * sigma ** 2)) for i in range(window_size)]
        ).to(device, non_blocking=True)
        gauss_1d = (gauss_1d / gauss_1d.sum()).unsqueeze(1)
        # The outer product gives the 2D window, shaped 1 x 1 x window_size x window_size.
        gauss_2d = gauss_1d.mm(gauss_1d.t()).float().unsqueeze(0).unsqueeze(0)
        # One copy of the window per channel (3), for use with groups=3.
        # A plain tensor is stored: torch.autograd.Variable is deprecated and
        # wrapping a tensor in it is a no-op.
        self.window = gauss_2d.expand(3, 1, window_size, window_size).contiguous()

    def heatmap(self, img1, img2):
        """
        Compute the SSIM heatmap between 2 images, based upon https://github.com/Po-Hsun-Su/pytorch-ssim
        Args:
            img1: Image tensor with values approx. between [-1,1], with 3 channels
            img2: Image tensor with values approx. between [-1,1], with 3 channels
        Returns:
            ssim_map: Per-channel SSIM map, same shape as the inputs when window_size is odd
        """
        # Follow the inputs' device and dtype so a window built on the
        # constructor's device still works with images stored elsewhere.
        window = self.window.to(device=img1.device, dtype=img1.dtype)
        pad = self.window_size // 2

        def blur(t):
            # Gaussian-weighted local mean, computed independently per channel.
            return F.conv2d(t, window, padding=pad, groups=3)

        mu1 = blur(img1)
        mu2 = blur(img2)
        mu1_sq = mu1.pow(2)
        mu2_sq = mu2.pow(2)
        mu1_mu2 = mu1 * mu2
        # Local variances and covariance: E[ab] - E[a]E[b].
        sigma1_sq = blur(img1 * img1) - mu1_sq
        sigma2_sq = blur(img2 * img2) - mu2_sq
        sigma12 = blur(img1 * img2) - mu1_mu2
        # Stabilizing constants of the SSIM formula.
        C1 = 0.01 ** 2
        C2 = 0.03 ** 2
        numerator = (2 * mu1_mu2 + C1) * (2 * sigma12 + C2)
        denominator = (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)
        return numerator / denominator

    def apply(self, x, y):
        """
        Attenuate x using SSIM heatmap to concentrate changes of y wrt. x around edges
        Args:
            x: Image tensor with values approx. between [-1,1]
            y: Image tensor with values approx. between [-1,1], ex: original image
        Returns:
            y plus the difference x-y weighted by the SSIM heatmap
        """
        delta = x - y
        ssim_map = self.heatmap(x, y)
        # Sum the per-channel maps into one weight map (channel dim kept for broadcasting).
        ssim_map = torch.sum(ssim_map, dim=1, keepdim=True)
        # Negative SSIM values would flip the sign of the perturbation.
        ssim_map = torch.clamp_min(ssim_map, 0)
        return y + delta * ssim_map
def center_crop(x, scale):
    """ Center-crop an image so that the crop covers a given fraction of the original area
    Args:
        x: PIL image
        scale: target area scale
    Returns:
        The center-cropped image
    """
    # An area ratio of `scale` corresponds to a side ratio of sqrt(scale).
    side_ratio = np.sqrt(scale)
    # x.size is reversed before being passed on as the target size;
    # each edge is truncated to an integer.
    target_size = [int(edge * side_ratio) for edge in x.size[::-1]]
    return functional.center_crop(x, target_size)
def resize(x, scale):
    """ Resize an image so that its area is a given fraction of the original area
    Args:
        x: PIL image
        scale: target area scale
    Returns:
        The resized image
    """
    # An area ratio of `scale` corresponds to a side ratio of sqrt(scale).
    side_ratio = np.sqrt(scale)
    # x.size is reversed before being passed on as the target size;
    # each edge is truncated to an integer.
    target_size = [int(edge * side_ratio) for edge in x.size[::-1]]
    return functional.resize(x, target_size)
def get_dataloader(data_dir, transform=default_transform, batch_size=128, shuffle=False, num_workers=4):
    """ Get dataloader for the images in the data_dir. The data_dir must be of the form: input/0/...
    Args:
        data_dir: Root folder, read with torchvision's ImageFolder
        transform: Transform applied to every loaded image
        batch_size: Number of images per batch
        shuffle: Whether the loader reshuffles the data at every epoch
        num_workers: Number of worker processes used for loading
    Returns:
        A DataLoader over the ImageFolder dataset
    """
    dataset = datasets.ImageFolder(data_dir, transform=transform)
    # Forward the caller's `shuffle`; it used to be hard-coded to False,
    # so the argument was silently ignored.
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)
def pil_imgs_from_folder(folder):
    """ Get all images in the folder as PIL images
    Entries that cannot be opened as images are reported on stdout and skipped.
    Args:
        folder: Path of the folder to scan (not recursive)
    Returns:
        images: List of opened PIL images, in os.listdir order
        filenames: List of the matching file names (not full paths), same order
    """
    images = []
    filenames = []
    for filename in os.listdir(folder):
        try:
            img = Image.open(os.path.join(folder, filename))
        except Exception:
            # Best effort: skip unreadable entries, but unlike a bare `except`
            # let KeyboardInterrupt and SystemExit propagate.
            print("Error opening image: ", filename)
        else:
            filenames.append(filename)
            images.append(img)
    return images, filenames