"""Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models

Taken from https://github.com/CompVis/taming-transformers/blob/3ba01b241669f5ade541ce990f7650a3b8f65318/taming/modules/losses/lpips.py#L11
"""

import hashlib
import os
from collections import namedtuple

import requests
import torch
import torch.nn as nn
from torchvision import models
from tqdm import tqdm

URL_MAP = {"vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"}

CKPT_MAP = {"vgg_lpips": "vgg.pth"}

MD5_MAP = {"vgg_lpips": "d507d7349b931f0638a25a48a722f98a"}

def download(url, local_path, chunk_size=1024):
    os.makedirs(os.path.split(local_path)[0], exist_ok=True)
    with requests.get(url, stream=True) as r:
        total_size = int(r.headers.get("content-length", 0))
        with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
            with open(local_path, "wb") as f:
                for data in r.iter_content(chunk_size=chunk_size):
                    if data:
                        f.write(data)
                        # advance by the bytes actually written; the final
                        # chunk may be smaller than chunk_size
                        pbar.update(len(data))
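# Example use (hypothetical local path, not part of the original module):
# stream the LPIPS VGG weights to disk with a progress bar.
#
#   download(URL_MAP["vgg_lpips"], "checkpoints/vgg.pth")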

def md5_hash(path):
    with open(path, "rb") as f:
        content = f.read()
    return hashlib.md5(content).hexdigest()


def get_ckpt_path(name, root, check=False):
    assert name in URL_MAP
    path = os.path.join(root, CKPT_MAP[name])
    if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]):
        print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path))
        download(URL_MAP[name], path)
        md5 = md5_hash(path)
        assert md5 == MD5_MAP[name], md5
    return path

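# Example use: fetch (and, with check=True, md5-verify) the weights into the
# same cache directory that LPIPS.load_from_pretrained uses below. The root is
# just a relative path; any writable directory works.
#
#   path = get_ckpt_path("vgg_lpips", "taming/modules/autoencoder/lpips", check=True)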

class KeyNotFoundError(Exception):
    def __init__(self, cause, keys=None, visited=None):
        self.cause = cause
        self.keys = keys
        self.visited = visited
        messages = list()
        if keys is not None:
            messages.append("Key not found: {}".format(keys))
        if visited is not None:
            messages.append("Visited: {}".format(visited))
        messages.append("Cause:\n{}".format(cause))
        message = "\n".join(messages)
        super().__init__(message)

def retrieve(
    list_or_dict, key, splitval="/", default=None, expand=True, pass_success=False
):
    """Given a nested list or dict, return the value at ``key``, expanding
    callable nodes along the way if necessary and :attr:`expand` is ``True``.
    The expansion is done in-place.

    Parameters
    ----------
    list_or_dict : list or dict
        Possibly nested list or dictionary.
    key : str
        Path-like string ("key/to/value") listing, in order, the keys needed
        to reach the desired value. List indices can also be passed here.
    splitval : str
        String that delimits the keys of the different depth levels in
        ``key``.
    default : obj
        Value returned if :attr:`key` is not found.
    expand : bool
        Whether to expand callable nodes on the path or not.
    pass_success : bool
        If ``True``, also return a flag indicating whether the lookup
        succeeded.

    Returns
    -------
    The desired value, or ``default`` if :attr:`default` is not ``None`` and
    :attr:`key` is not found.

    Raises
    ------
    KeyNotFoundError
        If ``key`` is not found in ``list_or_dict`` and :attr:`default` is
        ``None``.
    """

    keys = key.split(splitval)

    success = True
    try:
        visited = []
        parent = None
        last_key = None
        for key in keys:
            if callable(list_or_dict):
                if not expand:
                    raise KeyNotFoundError(
                        ValueError(
                            "Trying to get past callable node with expand=False."
                        ),
                        keys=keys,
                        visited=visited,
                    )
                list_or_dict = list_or_dict()
                parent[last_key] = list_or_dict

            last_key = key
            parent = list_or_dict

            try:
                if isinstance(list_or_dict, dict):
                    list_or_dict = list_or_dict[key]
                else:
                    list_or_dict = list_or_dict[int(key)]
            except (KeyError, IndexError, ValueError) as e:
                raise KeyNotFoundError(e, keys=keys, visited=visited)

            visited += [key]
        # final expansion of retrieved value
        if expand and callable(list_or_dict):
            list_or_dict = list_or_dict()
            parent[last_key] = list_or_dict
    except KeyNotFoundError as e:
        if default is None:
            raise e
        else:
            list_or_dict = default
            success = False

    if not pass_success:
        return list_or_dict
    else:
        return list_or_dict, success

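# A quick illustration of retrieve() (illustrative values only): keys are
# split on the delimiter and applied in order, with list indices given as
# strings.
#
#   cfg = {"model": {"params": {"lr": 1e-4}}, "stages": [8, 16]}
#   retrieve(cfg, "model/params/lr")            # -> 1e-4
#   retrieve(cfg, "stages/1")                   # -> 16
#   retrieve(cfg, "missing/key", default=0.0)   # -> 0.0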

class LPIPS(nn.Module):
    # Learned perceptual metric
    def __init__(self, use_dropout=True):
        super().__init__()
        self.scaling_layer = ScalingLayer()
        self.chns = [64, 128, 256, 512, 512]  # vgg16 features
        self.net = vgg16(pretrained=True, requires_grad=False)
        self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
        self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
        self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
        self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
        self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
        self.load_from_pretrained()
        for param in self.parameters():
            param.requires_grad = False

    def load_from_pretrained(self, name="vgg_lpips"):
        ckpt = get_ckpt_path(name, "taming/modules/autoencoder/lpips")
        self.load_state_dict(
            torch.load(ckpt, map_location=torch.device("cpu")), strict=False
        )
        print("loaded pretrained LPIPS loss from {}".format(ckpt))

    @classmethod
    def from_pretrained(cls, name="vgg_lpips"):
        if name != "vgg_lpips":
            raise NotImplementedError
        model = cls()
        # get_ckpt_path has no default root; use the same cache directory as
        # load_from_pretrained
        ckpt = get_ckpt_path(name, "taming/modules/autoencoder/lpips")
        model.load_state_dict(
            torch.load(ckpt, map_location=torch.device("cpu")), strict=False
        )
        return model

    def forward(self, input, target):
        in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target))
        outs0, outs1 = self.net(in0_input), self.net(in1_input)
        feats0, feats1, diffs = {}, {}, {}
        lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
        for kk in range(len(self.chns)):
            feats0[kk] = normalize_tensor(outs0[kk])
            feats1[kk] = normalize_tensor(outs1[kk])
            diffs[kk] = (feats0[kk] - feats1[kk]) ** 2

        res = [
            spatial_average(lins[kk].model(diffs[kk]), keepdim=True)
            for kk in range(len(self.chns))
        ]
        val = res[0]
        for l in range(1, len(self.chns)):
            val += res[l]
        return val

class ScalingLayer(nn.Module):
    def __init__(self):
        super(ScalingLayer, self).__init__()
        self.register_buffer(
            "shift", torch.Tensor([-0.030, -0.088, -0.188])[None, :, None, None]
        )
        self.register_buffer(
            "scale", torch.Tensor([0.458, 0.448, 0.450])[None, :, None, None]
        )

    def forward(self, inp):
        return (inp - self.shift) / self.scale

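# Note: the shift/scale buffers above equal 2*mean - 1 and 2*std for the
# standard ImageNet statistics (mean ~ [0.485, 0.456, 0.406],
# std ~ [0.229, 0.224, 0.225]), so ScalingLayer maps inputs scaled to [-1, 1]
# onto ImageNet-normalized values as expected by the VGG backbone.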

class NetLinLayer(nn.Module):
    """A single linear layer which does a 1x1 conv"""

    def __init__(self, chn_in, chn_out=1, use_dropout=False):
        super(NetLinLayer, self).__init__()
        layers = [nn.Dropout()] if use_dropout else []
        layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False)]
        self.model = nn.Sequential(*layers)

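# NetLinLayer provides the learned part of LPIPS: a 1x1 convolution that
# collapses the per-channel squared feature differences into a single-channel
# distance map (one such layer per VGG feature scale, see LPIPS.forward).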

class vgg16(torch.nn.Module):
    def __init__(self, requires_grad=False, pretrained=True):
        super(vgg16, self).__init__()
        vgg_pretrained_features = models.vgg16(pretrained=pretrained).features
        self.slice1 = torch.nn.Sequential()
        self.slice2 = torch.nn.Sequential()
        self.slice3 = torch.nn.Sequential()
        self.slice4 = torch.nn.Sequential()
        self.slice5 = torch.nn.Sequential()
        self.N_slices = 5
        for x in range(4):
            self.slice1.add_module(str(x), vgg_pretrained_features[x])
        for x in range(4, 9):
            self.slice2.add_module(str(x), vgg_pretrained_features[x])
        for x in range(9, 16):
            self.slice3.add_module(str(x), vgg_pretrained_features[x])
        for x in range(16, 23):
            self.slice4.add_module(str(x), vgg_pretrained_features[x])
        for x in range(23, 30):
            self.slice5.add_module(str(x), vgg_pretrained_features[x])
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        h = self.slice1(X)
        h_relu1_2 = h
        h = self.slice2(h)
        h_relu2_2 = h
        h = self.slice3(h)
        h_relu3_3 = h
        h = self.slice4(h)
        h_relu4_3 = h
        h = self.slice5(h)
        h_relu5_3 = h
        vgg_outputs = namedtuple(
            "VggOutputs", ["relu1_2", "relu2_2", "relu3_3", "relu4_3", "relu5_3"]
        )
        out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
        return out


def normalize_tensor(x, eps=1e-10):
    norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True))
    return x / (norm_factor + eps)


def spatial_average(x, keepdim=True):
    return x.mean([2, 3], keepdim=keepdim)
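

# Minimal usage sketch (an assumption-laden example, not part of the original
# module): constructing LPIPS() downloads the pretrained weights on first use,
# and the ScalingLayer implies inputs scaled to [-1, 1].
if __name__ == "__main__":
    lpips = LPIPS().eval()
    x = torch.rand(4, 3, 64, 64) * 2 - 1  # reference batch in [-1, 1]
    y = (x + 0.1 * torch.randn_like(x)).clamp(-1, 1)  # perturbed copy
    with torch.no_grad():
        dist = lpips(x, y)  # shape (4, 1, 1, 1): one distance per image
    print(dist.flatten())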