-
Notifications
You must be signed in to change notification settings - Fork 382
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
4 changed files
with
52,599 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,285 @@ | ||
"""Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models | ||
Taken from https://github.com/CompVis/taming-transformers/blob/3ba01b241669f5ade541ce990f7650a3b8f65318/taming/modules/losses/lpips.py#L11 | ||
""" | ||
|
||
import hashlib | ||
import os | ||
from collections import namedtuple | ||
|
||
import requests | ||
import torch | ||
import torch.nn as nn | ||
from torchvision import models | ||
from tqdm import tqdm | ||
|
||
URL_MAP = {"vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"} | ||
|
||
CKPT_MAP = {"vgg_lpips": "vgg.pth"} | ||
|
||
MD5_MAP = {"vgg_lpips": "d507d7349b931f0638a25a48a722f98a"} | ||
|
||
|
||
def download(url, local_path, chunk_size=1024): | ||
os.makedirs(os.path.split(local_path)[0], exist_ok=True) | ||
with requests.get(url, stream=True) as r: | ||
total_size = int(r.headers.get("content-length", 0)) | ||
with tqdm(total=total_size, unit="B", unit_scale=True) as pbar: | ||
with open(local_path, "wb") as f: | ||
for data in r.iter_content(chunk_size=chunk_size): | ||
if data: | ||
f.write(data) | ||
pbar.update(chunk_size) | ||
|
||
|
||
def md5_hash(path): | ||
with open(path, "rb") as f: | ||
content = f.read() | ||
return hashlib.md5(content).hexdigest() | ||
|
||
|
||
def get_ckpt_path(name, root, check=False): | ||
assert name in URL_MAP | ||
path = os.path.join(root, CKPT_MAP[name]) | ||
if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]): | ||
print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path)) | ||
download(URL_MAP[name], path) | ||
md5 = md5_hash(path) | ||
assert md5 == MD5_MAP[name], md5 | ||
return path | ||
|
||
|
||
class KeyNotFoundError(Exception): | ||
def __init__(self, cause, keys=None, visited=None): | ||
self.cause = cause | ||
self.keys = keys | ||
self.visited = visited | ||
messages = list() | ||
if keys is not None: | ||
messages.append("Key not found: {}".format(keys)) | ||
if visited is not None: | ||
messages.append("Visited: {}".format(visited)) | ||
messages.append("Cause:\n{}".format(cause)) | ||
message = "\n".join(messages) | ||
super().__init__(message) | ||
|
||
|
||
def retrieve( | ||
list_or_dict, key, splitval="/", default=None, expand=True, pass_success=False | ||
): | ||
"""Given a nested list or dict return the desired value at key expanding | ||
callable nodes if necessary and :attr:`expand` is ``True``. The expansion | ||
is done in-place. | ||
Parameters | ||
---------- | ||
list_or_dict : list or dict | ||
Possibly nested list or dictionary. | ||
key : str | ||
key/to/value, path like string describing all keys necessary to | ||
consider to get to the desired value. List indices can also be | ||
passed here. | ||
splitval : str | ||
String that defines the delimiter between keys of the | ||
different depth levels in `key`. | ||
default : obj | ||
Value returned if :attr:`key` is not found. | ||
expand : bool | ||
Whether to expand callable nodes on the path or not. | ||
Returns | ||
------- | ||
The desired value or if :attr:`default` is not ``None`` and the | ||
:attr:`key` is not found returns ``default``. | ||
Raises | ||
------ | ||
Exception if ``key`` not in ``list_or_dict`` and :attr:`default` is | ||
``None``. | ||
""" | ||
|
||
keys = key.split(splitval) | ||
|
||
success = True | ||
try: | ||
visited = [] | ||
parent = None | ||
last_key = None | ||
for key in keys: | ||
if callable(list_or_dict): | ||
if not expand: | ||
raise KeyNotFoundError( | ||
ValueError( | ||
"Trying to get past callable node with expand=False." | ||
), | ||
keys=keys, | ||
visited=visited, | ||
) | ||
list_or_dict = list_or_dict() | ||
parent[last_key] = list_or_dict | ||
|
||
last_key = key | ||
parent = list_or_dict | ||
|
||
try: | ||
if isinstance(list_or_dict, dict): | ||
list_or_dict = list_or_dict[key] | ||
else: | ||
list_or_dict = list_or_dict[int(key)] | ||
except (KeyError, IndexError, ValueError) as e: | ||
raise KeyNotFoundError(e, keys=keys, visited=visited) | ||
|
||
visited += [key] | ||
# final expansion of retrieved value | ||
if expand and callable(list_or_dict): | ||
list_or_dict = list_or_dict() | ||
parent[last_key] = list_or_dict | ||
except KeyNotFoundError as e: | ||
if default is None: | ||
raise e | ||
else: | ||
list_or_dict = default | ||
success = False | ||
|
||
if not pass_success: | ||
return list_or_dict | ||
else: | ||
return list_or_dict, success | ||
|
||
|
||
class LPIPS(nn.Module): | ||
# Learned perceptual metric | ||
def __init__(self, use_dropout=True): | ||
super().__init__() | ||
self.scaling_layer = ScalingLayer() | ||
self.chns = [64, 128, 256, 512, 512] # vg16 features | ||
self.net = vgg16(pretrained=True, requires_grad=False) | ||
self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) | ||
self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) | ||
self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) | ||
self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) | ||
self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) | ||
self.load_from_pretrained() | ||
for param in self.parameters(): | ||
param.requires_grad = False | ||
|
||
def load_from_pretrained(self, name="vgg_lpips"): | ||
ckpt = get_ckpt_path(name, "taming/modules/autoencoder/lpips") | ||
self.load_state_dict( | ||
torch.load(ckpt, map_location=torch.device("cpu")), strict=False | ||
) | ||
print("loaded pretrained LPIPS loss from {}".format(ckpt)) | ||
|
||
@classmethod | ||
def from_pretrained(cls, name="vgg_lpips"): | ||
if name != "vgg_lpips": | ||
raise NotImplementedError | ||
model = cls() | ||
ckpt = get_ckpt_path(name) | ||
model.load_state_dict( | ||
torch.load(ckpt, map_location=torch.device("cpu")), strict=False | ||
) | ||
return model | ||
|
||
def forward(self, input, target): | ||
in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target)) | ||
outs0, outs1 = self.net(in0_input), self.net(in1_input) | ||
feats0, feats1, diffs = {}, {}, {} | ||
lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4] | ||
for kk in range(len(self.chns)): | ||
feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor( | ||
outs1[kk] | ||
) | ||
diffs[kk] = (feats0[kk] - feats1[kk]) ** 2 | ||
|
||
res = [ | ||
spatial_average(lins[kk].model(diffs[kk]), keepdim=True) | ||
for kk in range(len(self.chns)) | ||
] | ||
val = res[0] | ||
for l in range(1, len(self.chns)): | ||
val += res[l] | ||
return val | ||
|
||
|
||
class ScalingLayer(nn.Module): | ||
def __init__(self): | ||
super(ScalingLayer, self).__init__() | ||
self.register_buffer( | ||
"shift", torch.Tensor([-0.030, -0.088, -0.188])[None, :, None, None] | ||
) | ||
self.register_buffer( | ||
"scale", torch.Tensor([0.458, 0.448, 0.450])[None, :, None, None] | ||
) | ||
|
||
def forward(self, inp): | ||
return (inp - self.shift) / self.scale | ||
|
||
|
||
class NetLinLayer(nn.Module): | ||
"""A single linear layer which does a 1x1 conv""" | ||
|
||
def __init__(self, chn_in, chn_out=1, use_dropout=False): | ||
super(NetLinLayer, self).__init__() | ||
layers = ( | ||
[ | ||
nn.Dropout(), | ||
] | ||
if (use_dropout) | ||
else [] | ||
) | ||
layers += [ | ||
nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), | ||
] | ||
self.model = nn.Sequential(*layers) | ||
|
||
|
||
class vgg16(torch.nn.Module): | ||
def __init__(self, requires_grad=False, pretrained=True): | ||
super(vgg16, self).__init__() | ||
vgg_pretrained_features = models.vgg16(pretrained=pretrained).features | ||
self.slice1 = torch.nn.Sequential() | ||
self.slice2 = torch.nn.Sequential() | ||
self.slice3 = torch.nn.Sequential() | ||
self.slice4 = torch.nn.Sequential() | ||
self.slice5 = torch.nn.Sequential() | ||
self.N_slices = 5 | ||
for x in range(4): | ||
self.slice1.add_module(str(x), vgg_pretrained_features[x]) | ||
for x in range(4, 9): | ||
self.slice2.add_module(str(x), vgg_pretrained_features[x]) | ||
for x in range(9, 16): | ||
self.slice3.add_module(str(x), vgg_pretrained_features[x]) | ||
for x in range(16, 23): | ||
self.slice4.add_module(str(x), vgg_pretrained_features[x]) | ||
for x in range(23, 30): | ||
self.slice5.add_module(str(x), vgg_pretrained_features[x]) | ||
if not requires_grad: | ||
for param in self.parameters(): | ||
param.requires_grad = False | ||
|
||
def forward(self, X): | ||
h = self.slice1(X) | ||
h_relu1_2 = h | ||
h = self.slice2(h) | ||
h_relu2_2 = h | ||
h = self.slice3(h) | ||
h_relu3_3 = h | ||
h = self.slice4(h) | ||
h_relu4_3 = h | ||
h = self.slice5(h) | ||
h_relu5_3 = h | ||
vgg_outputs = namedtuple( | ||
"VggOutputs", ["relu1_2", "relu2_2", "relu3_3", "relu4_3", "relu5_3"] | ||
) | ||
out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) | ||
return out | ||
|
||
|
||
def normalize_tensor(x, eps=1e-10): | ||
norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True)) | ||
return x / (norm_factor + eps) | ||
|
||
|
||
def spatial_average(x, keepdim=True): | ||
return x.mean([2, 3], keepdim=keepdim) |
Oops, something went wrong.