-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathvis_util.py
32 lines (27 loc) · 1.05 KB
/
vis_util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import os
import torch
from torchvision.utils import save_image
from einops import rearrange
def get_id(img_att, num_prototype):
    """
    Group the positions of the first sample by their highest-scoring prototype.

    Args:
        img_att: attention/score tensor whose dim 1 indexes prototypes.
            Presumably shaped [B, num_prototype, N] -- TODO confirm against callers.
        num_prototype: number of prototype ids to report (0 .. num_prototype - 1).
    Return:
        prototype_id: list of length num_prototype; entry i is a 1-D LongTensor
            holding the positions (along the first axis after dim 1) of sample 0
            whose top-1 prototype is i. Entries are empty when no position picks i.
    """
    # Top-1 prototype per position; topk keeps a size-1 axis at dim 1.
    _, index = torch.topk(img_att, k=1, dim=1)
    # Drop only that k=1 axis. A bare squeeze() would also drop the batch axis
    # when B == 1 (or any other size-1 axis), making index[0] select a single
    # position instead of sample 0.
    index = index.squeeze(1)
    first_sample = index[0]
    return [torch.where(first_sample == i)[0] for i in range(num_prototype)]
def mask_image(x, idx, patch_size):
    """
    Keep only the selected square patches of an image and zero out the rest.

    Args:
        x: input image, shape: [B, C, H, W]; H and W must be multiples of
            patch_size (the rearrange patterns below require it).
        idx: integer (int64) patch indices, shape: [B, T], values in [0, h*w),
            where h = H // patch_size and w = W // patch_size.
        patch_size: side length in pixels of each square patch.
    Return:
        out_img: tensor shaped like x in which only the patches listed in idx
            keep their pixels; every other patch is zero.
    """
    grid_h = x.size(2) // patch_size
    # One column per patch: [B, C*p*q, h*w].
    columns = rearrange(x, 'b c (h p) (w q) -> b (c p q) (h w)', p=patch_size, q=patch_size)
    # Repeat each patch index across all C*p*q entries of its column.
    column_index = idx.unsqueeze(1).expand(-1, columns.size(1), -1)
    picked = columns.gather(2, column_index)  # [B, C*p*q, T]
    # Write the picked columns back at their original positions on a zero canvas.
    canvas = torch.zeros_like(columns).scatter(2, column_index, picked)
    return rearrange(canvas, 'b (c p q) (h w) -> b c (h p) (w q)', p=patch_size, q=patch_size, h=grid_h)