# attack.py
import argparse
import json
import matplotlib.pyplot as plt
import torchvision
import torchvision.transforms as T
from torch.backends import cudnn
from torch.utils.data import Dataset
import os
import csv
import numpy as np
import wandb
from surrogate import *
from our_dataset import OUR_dataset
from utils import *
from classification import classify
# torch and nn are imported explicitly rather than relying on what the star imports above re-export.
import torch
import torch.nn as nn
import torch.nn.functional as F
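
# Example invocations (illustrative only; flags and values mirror the argparse defaults
# defined below):
#   python attack.py --mode rotate --ae_dir ./trained_models --save_dir ./adv_images \
#       --epsilon 0.1 --ce_method ifgsm
#   python attack.py --mode prototypical --n_decoders 20 --ae_dir ./trained_models \
#       --save_dir ./adv_images_proto
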
parser = argparse.ArgumentParser(description='Attack')
parser.add_argument('--project', type=str, default="APR")
parser.add_argument('--entity', type=str, default="hashmatshadab")
parser.add_argument('--wandb_mode', type=str, default="disabled")
parser.add_argument('--epsilon', type=float, default=0.1, help="Perturbation budget of the attack")
parser.add_argument('--single_model', type=lambda x: (str(x).lower() == 'true'), default=False,
                    help="Use one fixed autoencoder checkpoint (--chk_pth) for all batches")
parser.add_argument('--chk_pth', type=str, default="trained_models/models/0.pth",
                    help="Autoencoder checkpoint used when --single_model is true")
parser.add_argument('--ila_niters', type=int, default=100, help="Number of ILA iterations")
parser.add_argument('--ce_niters', type=int, default=200, help="Number of baseline (CE) attack iterations")
parser.add_argument('--ce_epsilon', type=float, default=0.1, help="Perturbation budget of the baseline attack")
parser.add_argument('--ce_alpha', type=float, default=1.0, help="Scaling parameter for the adversarial loss")
parser.add_argument('--n_imgs', type=int, default=20, help="Images per attack batch (half from each of the two classes)")
parser.add_argument('--n_decoders', type=int, default=20, help="Number of decoders in the autoencoder")
parser.add_argument('--ae_dir', type=str, default='./trained_models', help="Path from which to load the trained autoencoders")
parser.add_argument('--save_dir', type=str, default='./adv_images', help="Path where adversarial images will be saved")
parser.add_argument('--mode', type=str, default='rotate', help="Mode with which the autoencoders were trained")
parser.add_argument('--ce_method', type=str, default='ifgsm', help="Baseline gradient attack: 'ifgsm' or 'pgd'")
parser.add_argument('--start', type=int, default=0, help="First batch index to attack")
parser.add_argument('--end', type=int, default=2500, help="One past the last batch index to attack")
parser.add_argument('--loss', type=str, default="baseline", choices=["baseline", "unsup"])
parser.add_argument('--opl_gamma', type=float, default=0.5)
parser.add_argument('--save_results', type=str, default='results',
                    help="Name of the file for saving classification scores on various models")


class ILA(torch.nn.Module):
    def __init__(self):
        super(ILA, self).__init__()

    def forward(self, ori_mid, tar_mid, att_mid):
        """
        ILA loss: the projection of the current mid-layer perturbation onto the
        (normalized) direction of the baseline attack's mid-layer perturbation,
        averaged over the batch. Maximizing it keeps the attack aligned with the
        baseline direction while letting its magnitude grow.
        """
        bs = ori_mid.shape[0]
        ori_mid = ori_mid.view(bs, -1)
        tar_mid = tar_mid.view(bs, -1)
        att_mid = att_mid.view(bs, -1)
        W = att_mid - ori_mid
        V = tar_mid - ori_mid
        V = V / V.norm(p=2, dim=1, keepdim=True)
        ILA = (W * V).sum() / bs
        return ILA
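
# Worked example of the ILA objective (illustrative values only):
#   ori = torch.zeros(1, 4)            # original mid-layer features
#   tar = torch.ones(1, 4)             # features after the baseline attack
#   att = 0.5 * torch.ones(1, 4)       # features of the current adversarial image
#   ILA()(ori, tar, att)               # (W * V).sum() / bs = (0.5 * 0.5) * 4 / 1 = 1.0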


def save_attack_img(img, file_dir):
    T.ToPILImage()(img.data.cpu()).save(file_dir)


def initialize_model(decoder_num):
    """
    Initialize the autoencoder surrogate with the given number of decoders.
    :param decoder_num: Number of decoders (20 for the prototypical mode, 1 for the other modes).
    """
    model = autoencoder(input_nc=3, output_nc=3, n_blocks=3, decoder_num=decoder_num)
    model = nn.Sequential(
        Normalize(),
        model,
    )
    model.to(device)
    return model
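
# Usage sketch (illustrative; assumes Normalize() from utils standardizes [0, 1] inputs,
# as produced by T.ToTensor(), so the attack can operate directly in pixel space):
#   model = initialize_model(decoder_num=1)
#   recons, enc_feats, _ = model(torch.rand(2, 3, 224, 224, device=device))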


def attack_ila(model, ori_img, tar_img, attack_niters, eps):
    """
    Applies the ILA refinement on top of the baseline gradient attack.
    :param ori_img: The original input images.
    :param tar_img: The images produced by the baseline gradient attack.
    :param attack_niters: Number of ILA iterations.
    :param eps: Maximum perturbation budget for ILA.
    :return: The adversarial images after ILA refinement.
    """
    model.eval()
    ori_img = ori_img.to(device)
    img = ori_img.clone()
    with torch.no_grad():
        # Encoder features of the baseline-attacked and original images (no gradients needed).
        _, tar_h_feats, _ = model(tar_img)
        _, ori_h_feats, _ = model(ori_img)
    # ori_h_feats: features of the original images
    # tar_h_feats: features of the attacked images before ILA
    for i in range(attack_niters):
        img.requires_grad_(True)
        # att_h_feats: features of the current adversarial images inside the loop
        _, att_h_feats, _ = model(img)
        loss = ILA()(ori_h_feats.detach(), tar_h_feats.detach(), att_h_feats)
        if (i + 1) % 50 == 0:
            print('\r ila attacking {}, {:0.4f}'.format(i + 1, loss.item()), end=' ')
        loss.backward()
        input_grad = img.grad.data.sign()
        # Sign step followed by projection onto the eps-ball and the valid pixel range.
        img = img.data + 1. / 255 * input_grad
        img = torch.where(img > ori_img + eps, ori_img + eps, img)
        img = torch.where(img < ori_img - eps, ori_img - eps, img)
        img = torch.clamp(img, min=0, max=1)
    print('')
    return img.data
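
# The update above is the standard l_inf sign step with projection; a minimal standalone
# sketch of the same step (illustrative, not part of the original code):
#   def linf_step(img, grad, ori_img, eps, step=1. / 255):
#       img = img + step * grad.sign()
#       img = torch.min(torch.max(img, ori_img - eps), ori_img + eps)
#       return img.clamp(0, 1)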


def attack_ce_unsup(model, ori_img, attack_niters, eps, args, alpha, n_imgs, ce_method, attack_loss, iter):
    """
    Baseline gradient attack (ce_method) for models trained with the rotate/jigsaw/masking approaches.
    :param model: The surrogate autoencoder.
    :param ori_img: The original input images.
    :param attack_niters: Number of baseline gradient-attack iterations.
    :param eps: Maximum perturbation budget for the baseline attack.
    :param alpha: Scaling parameter for the adversarial loss.
    :param n_imgs: Number of images per class in the batch.
    :param ce_method: The gradient-based baseline attack to use ('ifgsm' or 'pgd').
    :return: The images with the adversarial loss maximized within the eps bound.
    """
    model.eval()
    ori_img = ori_img.to(device)
    nChannels = 3
    tar_img = []
    if args.loss == "unsup":
        # Each image is simply its own reconstruction target.
        for i in range(2 * n_imgs):
            tar_img.append(ori_img[i].unsqueeze(0))
        tar_img = torch.cat(tar_img, dim=0)
    else:
        # For each image, the target pair is (itself, its cross-class counterpart).
        for i in range(n_imgs):
            tar_img.append(ori_img[[i, n_imgs + i]])
        for i in range(n_imgs):
            tar_img.append(ori_img[[n_imgs + i, i]])
        tar_img = torch.cat(tar_img, dim=0)
        tar_img = tar_img.reshape(2 * n_imgs, 2, nChannels, 224, 224)
    img = ori_img.clone()
    attack_loss[iter] = []
    for i in range(attack_niters):
        if ce_method == 'ifgsm':
            img_x = img
        elif ce_method == 'pgd':
            # Our PGD variant adds fresh random noise at every iteration to further enhance transferability.
            img_x = img + img.new(img.size()).uniform_(-eps, eps)
        img_x.requires_grad_(True)
        outs, enc_out, _ = model(img_x)
        if args.loss == "baseline":
            outs = outs[0].unsqueeze(1).repeat(1, 2, 1, 1, 1)
            loss_mse_ = nn.MSELoss(reduction='none')(outs, tar_img).sum(dim=(2, 3, 4)) / (nChannels * 224 * 224)
            loss_mse = -alpha * loss_mse_
            label = torch.tensor([0] * n_imgs * 2).long().to(device)
            loss = nn.CrossEntropyLoss()(loss_mse, label)
        elif args.loss == "unsup":
            outs = outs[0]
            loss = nn.MSELoss(reduction='none')(outs, tar_img).sum() / (2 * n_imgs * nChannels * 224 * 224)
        attack_loss[iter].append(loss.item())
        if (i + 1) % 50 == 0 or i == 0:
            print('\r attacking {}, {:0.4f}'.format(i, loss.item()), end=' ')
        loss.backward()
        adv_noise = img_x.grad
        input_grad = adv_noise.data.sign()
        img = img.data + 1. / 255 * input_grad
        img = torch.where(img > ori_img + eps, ori_img + eps, img)
        img = torch.where(img < ori_img - eps, ori_img - eps, img)
        img = torch.clamp(img, min=0, max=1)
    print('')
    return img.data
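
# Sketch of the 'baseline' loss above (illustrative numbers): per image, the two-way
# "logits" are the negated, alpha-scaled reconstruction errors against its own target
# (index 0) and the swapped-class target (index 1). The sign-ascent step maximizes
# CrossEntropyLoss with label 0, i.e. it grows MSE(own) relative to MSE(swapped):
#   loss_mse = -1.0 * torch.tensor([[0.2, 0.9]])         # [MSE(own), MSE(swapped)]
#   nn.CrossEntropyLoss()(loss_mse, torch.tensor([0]))   # increases as MSE(own) grows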


def plot_grid(w):
    grid_img = torchvision.utils.make_grid(w)
    plt.imshow(grid_img.permute(1, 2, 0).cpu())
    plt.show()


def create_json(args):
    """
    Save the run arguments (args) to a JSON config file in the save directory.
    """
    with open(f"{args.save_dir}/config_attack.json", "w") as write_file:
        json.dump(args.__dict__, write_file, indent=4)
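
# Counterpart sketch for reloading the saved config later (illustrative, not used here):
#   with open(f"{args.save_dir}/config_attack.json") as f:
#       loaded_args = argparse.Namespace(**json.load(f))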


def attack_ce_proto(model, ori_img, attack_niters, args, eps, alpha, n_decoders, ce_method, n_imgs, prototype_inds,
                    attack_loss, iter):
    """
    Baseline gradient attack (ce_method) for models trained with the prototypical reconstruction approach.
    :param ori_img: The original input images.
    :param attack_niters: Number of iterations for the CE baseline gradient attack.
    :param eps: Maximum perturbation budget for the CE attack.
    :param ce_method: The type of gradient-based baseline attack ('ifgsm' or 'pgd').
    :param prototype_inds: The list of prototype indices.
    :param attack_loss: Dict collecting the adversarial loss per batch.
    :return: The images perturbed (within the eps bound) to maximize the adversarial loss.
    """
    model.eval()
    ori_img = ori_img.to(device)
    tar_img = []
    for i in range(n_decoders):
        # One prototype pair per decoder, just like in training.
        tar_img.append(ori_img[[prototype_inds[2 * i], prototype_inds[2 * i + 1]]])
    tar_img = torch.cat(tar_img, dim=0)
    nChannels = 3
    if n_decoders == 1:
        decoder_size = 224
    else:
        decoder_size = 56
        tar_img = F.interpolate(tar_img, size=(56, 56))  # [40, 3, 56, 56]
    tar_img = tar_img.reshape(n_decoders, 2, nChannels, decoder_size, decoder_size).unsqueeze(1)  # [20, 1, 2, 3, 56, 56]
    # The 40 prototype images form 20 pairs; each pair holds one image from each of the two classes.
    tar_img = tar_img.repeat(1, n_imgs * 2, 1, 1, 1, 1).reshape(n_imgs * 2 * n_decoders, 2, nChannels,
                                                                decoder_size, decoder_size)  # [400, 2, 3, 56, 56]
    # 400 pairs in total: each of the 20 prototype pairs is repeated 2*n_imgs times, so the first 20 pairs are identical, and so on.
    img = ori_img.clone()
    attack_loss[iter] = []
    for i in range(attack_niters):
        if ce_method == 'ifgsm':
            img_x = img
        elif ce_method == 'pgd':
            img_x = img + img.new(img.size()).uniform_(-eps, eps)
        img_x.requires_grad_(True)
        outs, enc_out, _ = model(img_x)
        # Each decoder reconstructs every input image (20 outputs per image) >> total [400, 3, 56, 56].
        outs = torch.cat(outs, dim=0).unsqueeze(1).repeat(1, 2, 1, 1, 1)  # [400, 2, 3, 56, 56]
        loss_mse_ = nn.MSELoss(reduction='none')(outs, tar_img).sum(dim=(2, 3, 4)) / (
            nChannels * decoder_size * decoder_size)  # [400, 2]
        loss_mse = -alpha * loss_mse_
        label = torch.tensor(([0] * n_imgs + [1] * n_imgs) * n_decoders).long().to(device)
        loss = nn.CrossEntropyLoss()(loss_mse, label)
        if (i + 1) % 50 == 0 or i == 0:
            print('attacking {}, {:0.4f}'.format(i, loss.item()))
        attack_loss[iter].append(loss.item())
        loss.backward()
        adv_noise = img_x.grad
        input_grad = adv_noise.data.sign()
        img = img.data + 1. / 255 * input_grad
        img = torch.where(img > ori_img + eps, ori_img + eps, img)
        img = torch.where(img < ori_img - eps, ori_img - eps, img)
        img = torch.clamp(img, min=0, max=1)
    print('')
    return img.data
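
# Shape walk-through for the default prototypical setting (n_imgs = 10 per class,
# n_decoders = 20), matching the inline shape comments above:
#   ori_img  : [20, 3, 224, 224]   -> the 2*n_imgs input images
#   tar_img  : [40, 3, 56, 56]     -> one prototype pair per decoder, downsampled
#              -> reshaped to [20, 1, 2, 3, 56, 56] -> repeated to [400, 2, 3, 56, 56]
#   outs     : 20 decoder outputs per image, concatenated to [400, 3, 56, 56]
#   loss_mse : [400, 2] two-way logits fed to CrossEntropyLoss with labels
#              ([0]*10 + [1]*10) * 20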


if __name__ == '__main__':
    args = parser.parse_args()
    wandb.init(project=args.project, entity=args.entity, mode=args.wandb_mode, name=args.save_dir.split("/")[-1])

    # Fix all seeds for reproducibility.
    SEED = 0
    cudnn.benchmark = False
    cudnn.deterministic = True
    torch.manual_seed(SEED)
    torch.cuda.manual_seed(SEED)
    np.random.seed(SEED)
    print(args)

    config = wandb.config
    config.update(args)
    mode = args.mode
    save_dir = args.save_dir
    n_imgs = args.n_imgs // 2
    if mode != 'prototypical':
        n_decoders = 1
    else:
        n_decoders = args.n_decoders
    assert n_decoders <= n_imgs ** 2, 'Too many decoders.'
    os.makedirs(save_dir, exist_ok=True)

    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    batch_size = n_imgs * 2
    epsilon = args.epsilon
    ce_epsilon = args.ce_epsilon
    ila_niters = args.ila_niters
    ce_niters = args.ce_niters
    ce_alpha = args.ce_alpha
    ae_dir = args.ae_dir
    ce_method = args.ce_method
    assert ce_method in ['ifgsm', 'pgd']

    trans = T.Compose([
        T.Resize((256, 256)),
        T.CenterCrop(224),
        T.ToTensor()
    ])
    dataset = OUR_dataset(data_dir='data/ILSVRC2012_img_val',
                          data_csv_dir='data/selected_data.csv',
                          mode='attack',
                          img_num=n_imgs,
                          transform=trans)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=1)
    create_json(args)

    fig, ax = plt.subplots()
    if args.single_model:
        # A single checkpoint is reused for every batch.
        args.loss = 'unsup'
        model = initialize_model(decoder_num=n_decoders)
        model.load_state_dict(torch.load(args.chk_pth))
        model.eval()

    for data_ind, (ori_img, _) in enumerate(dataloader):
        if not args.start <= data_ind < args.end:
            continue
        if not args.single_model:
            # Otherwise, load the autoencoder trained on this particular batch.
            model = initialize_model(n_decoders)
            model.load_state_dict(torch.load('{}/models/{}_{}.pth'.format(ae_dir, args.mode, data_ind)))
            model.eval()
        ori_img = ori_img.to(device)
        attack_loss = {}
        if mode == 'prototypical':
            with open(ae_dir + '/prototype_ind.csv', 'r') as prototype_ind_csv:
                prototype_ind_ls = list(csv.reader(prototype_ind_csv))
            old_att_img = attack_ce_proto(model, ori_img, args=args, attack_niters=ce_niters,
                                          eps=ce_epsilon, alpha=ce_alpha, n_decoders=n_decoders,
                                          ce_method=ce_method, n_imgs=n_imgs,
                                          prototype_inds=list(map(int, prototype_ind_ls[data_ind])),
                                          attack_loss=attack_loss, iter=data_ind)
        else:
            old_att_img = attack_ce_unsup(model, ori_img, args=args, attack_niters=ce_niters,
                                          eps=ce_epsilon, alpha=ce_alpha, n_imgs=n_imgs,
                                          ce_method=ce_method,
                                          attack_loss=attack_loss, iter=data_ind)
        xs = list(range(len(attack_loss[data_ind])))
        ax.plot(xs, attack_loss[data_ind], label=f"Model_{data_ind}")
        ax.set_xlabel("Iterations")
        ax.set_ylabel("Loss")
        ax.set_title(f"Model_{mode}_{data_ind}")
        wandb.log({'plot': ax})
        att_img = attack_ila(model, ori_img, old_att_img, ila_niters, eps=epsilon)
        for save_ind in range(batch_size):
            file_path, file_name = dataset.imgs[data_ind * 2 * n_imgs + save_ind][0].split('/')[-2:]
            os.makedirs(save_dir + '/' + file_path, exist_ok=True)
            # file_name[:-5] strips the '.JPEG' extension before saving as PNG.
            save_attack_img(img=att_img[save_ind],
                            file_dir=os.path.join(save_dir, file_path, file_name[:-5]) + '.png')
            print('\r', data_ind * batch_size + save_ind, 'images saved.', end=' ')
    classify(save_dir=save_dir, batch_size=batch_size, save_results=args.save_results)