Commit fa74e50

hw3_helper update
1 parent: 2369f60

4 files changed: 52599 additions, 8 deletions

deepul/hw3_helper.py (6 additions, 8 deletions)

@@ -47,7 +47,7 @@ def q1_gan_plot(data, samples, xs, ys, title, fname):
 
 def q1_data(n=20000):
     assert n % 2 == 0
-    gaussian1 = np.random.normal(loc=-1.5, scale=0.22, size=(n//2,))
+    gaussian1 = np.random.normal(loc=-1.5, scale=0.35, size=(n//2,))
     gaussian2 = np.random.normal(loc=0.2, scale=0.6, size=(n//2,))
     data = (np.concatenate([gaussian1, gaussian2]) + 1).reshape([-1, 1])
     scaled_data = (data - np.min(data)) / (np.max(data) - np.min(data) + 1e-8)
@@ -160,16 +160,14 @@ def save_plot(
 
 def q3_save_results(fn, part):
     train_data, test_data = load_q3_data()
-    gan_losses, optional_lpips_losses, l2_train_losses, l2_val_losses, recon_show, recon_is = fn(train_data, test_data, test_data[:100])
+    gan_losses, lpips_losses, l2_train_losses, l2_val_losses, recon_show = fn(train_data, test_data, test_data[:100])
 
-    plot_gan_training(gan_losses, f'Q3{part} Losses', f'results/q3{part}_gan_losses.png')
+    plot_gan_training(gan_losses, f'Q3{part} Discriminator Losses', f'results/q3{part}_gan_losses.png')
     save_plot(l2_train_losses, l2_val_losses, f'Q3{part} L2 Losses', f'results/q3{part}_l2_losses.png')
-    if optional_lpips_losses is not None:
-        save_plot(optional_lpips_losses, None, f'Q3{part} LPIPS Losses', f'results/q3{part}_lpips_losses.png')
+    save_plot(lpips_losses, None, f'Q3{part} LPIPS Losses', f'results/q3{part}_lpips_losses.png')
     show_samples(test_data[:100].transpose(0, 2, 3, 1) * 255.0, nrow=20, fname=f'results/q3{part}_data_samples.png', title=f'Q3{part} CIFAR10 val samples')
     show_samples(recon_show * 255.0, nrow=20, fname=f'results/q3{part}_reconstructions.png', title=f'Q3{part} VQGAN reconstructions')
-    print('inception score:', calculate_is(recon_is.transpose([0, 2, 3, 1])))
-    print('final_reconstruction_loss:', l2_val_losses[-1])
+    print('final_val_reconstruction_loss:', l2_val_losses[-1])
 
 ######################
 ##### Question 4 #####
@@ -178,7 +176,7 @@ def q3_save_results(fn, part):
 def get_colored_mnist(data):
     # from https://www.wouterbulten.nl/blog/tech/getting-started-with-gans-2-colorful-mnist/
     # Read Lena image
-    lena = PILImage.open('deepul/deepul/hw4_utils/lena.jpg')
+    lena = PILImage.open('deepul/deepul/hw3_utils/lena.jpg')
 
     # Resize
     batch_resized = np.asarray([scipy.ndimage.zoom(image, (2.3, 2.3, 1), order=1) for image in data])
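
Note on the q3_save_results change: the student-supplied fn now returns five values instead of six (LPIPS losses are plotted unconditionally rather than being optional, and the inception-score batch recon_is is gone). A minimal sketch of the new return contract follows, with dummy values; the shapes are assumptions inferred from how q3_save_results plots and displays the results, not part of the commit:

# Hypothetical stub illustrating the five-tuple that the updated
# q3_save_results expects fn to return.
import numpy as np

def q3(train_data, val_data, reconstruct_data):
    # train_data, val_data: (N, 3, 32, 32) CIFAR-10 arrays (assumed)
    # reconstruct_data: the first 100 val images, passed in as test_data[:100]
    n_iters, n_epochs = 1000, 10  # placeholder training lengths
    gan_losses = np.zeros(n_iters)       # discriminator loss per iteration
    lpips_losses = np.zeros(n_iters)     # LPIPS loss per iteration, now required
    l2_train_losses = np.zeros(n_iters)  # L2 reconstruction loss on train data
    l2_val_losses = np.zeros(n_epochs)   # L2 reconstruction loss on val data
    # Reconstructions for show_samples: assumed NHWC in [0, 1], since
    # q3_save_results multiplies by 255 without transposing.
    recon_show = np.zeros((100, 32, 32, 3))
    return gan_losses, lpips_losses, l2_train_losses, l2_val_losses, recon_show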

deepul/hw3_utils/lpips.py (new file: 285 additions, 0 deletions)

"""Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models

Taken from https://github.com/CompVis/taming-transformers/blob/3ba01b241669f5ade541ce990f7650a3b8f65318/taming/modules/losses/lpips.py#L11
"""

import hashlib
import os
from collections import namedtuple

import requests
import torch
import torch.nn as nn
from torchvision import models
from tqdm import tqdm

URL_MAP = {"vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"}

CKPT_MAP = {"vgg_lpips": "vgg.pth"}

MD5_MAP = {"vgg_lpips": "d507d7349b931f0638a25a48a722f98a"}


def download(url, local_path, chunk_size=1024):
    os.makedirs(os.path.split(local_path)[0], exist_ok=True)
    with requests.get(url, stream=True) as r:
        total_size = int(r.headers.get("content-length", 0))
        with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
            with open(local_path, "wb") as f:
                for data in r.iter_content(chunk_size=chunk_size):
                    if data:
                        f.write(data)
                        pbar.update(chunk_size)


def md5_hash(path):
    with open(path, "rb") as f:
        content = f.read()
    return hashlib.md5(content).hexdigest()


def get_ckpt_path(name, root, check=False):
    assert name in URL_MAP
    path = os.path.join(root, CKPT_MAP[name])
    if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]):
        print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path))
        download(URL_MAP[name], path)
        md5 = md5_hash(path)
        assert md5 == MD5_MAP[name], md5
    return path


class KeyNotFoundError(Exception):
    def __init__(self, cause, keys=None, visited=None):
        self.cause = cause
        self.keys = keys
        self.visited = visited
        messages = list()
        if keys is not None:
            messages.append("Key not found: {}".format(keys))
        if visited is not None:
            messages.append("Visited: {}".format(visited))
        messages.append("Cause:\n{}".format(cause))
        message = "\n".join(messages)
        super().__init__(message)


def retrieve(
    list_or_dict, key, splitval="/", default=None, expand=True, pass_success=False
):
    """Given a nested list or dict return the desired value at key expanding
    callable nodes if necessary and :attr:`expand` is ``True``. The expansion
    is done in-place.

    Parameters
    ----------
    list_or_dict : list or dict
        Possibly nested list or dictionary.
    key : str
        key/to/value, path like string describing all keys necessary to
        consider to get to the desired value. List indices can also be
        passed here.
    splitval : str
        String that defines the delimiter between keys of the
        different depth levels in `key`.
    default : obj
        Value returned if :attr:`key` is not found.
    expand : bool
        Whether to expand callable nodes on the path or not.

    Returns
    -------
    The desired value or if :attr:`default` is not ``None`` and the
    :attr:`key` is not found returns ``default``.

    Raises
    ------
    Exception if ``key`` not in ``list_or_dict`` and :attr:`default` is
    ``None``.
    """

    keys = key.split(splitval)

    success = True
    try:
        visited = []
        parent = None
        last_key = None
        for key in keys:
            if callable(list_or_dict):
                if not expand:
                    raise KeyNotFoundError(
                        ValueError(
                            "Trying to get past callable node with expand=False."
                        ),
                        keys=keys,
                        visited=visited,
                    )
                list_or_dict = list_or_dict()
                parent[last_key] = list_or_dict

            last_key = key
            parent = list_or_dict

            try:
                if isinstance(list_or_dict, dict):
                    list_or_dict = list_or_dict[key]
                else:
                    list_or_dict = list_or_dict[int(key)]
            except (KeyError, IndexError, ValueError) as e:
                raise KeyNotFoundError(e, keys=keys, visited=visited)

            visited += [key]
        # final expansion of retrieved value
        if expand and callable(list_or_dict):
            list_or_dict = list_or_dict()
            parent[last_key] = list_or_dict
    except KeyNotFoundError as e:
        if default is None:
            raise e
        else:
            list_or_dict = default
            success = False

    if not pass_success:
        return list_or_dict
    else:
        return list_or_dict, success


class LPIPS(nn.Module):
    # Learned perceptual metric
    def __init__(self, use_dropout=True):
        super().__init__()
        self.scaling_layer = ScalingLayer()
        self.chns = [64, 128, 256, 512, 512]  # vgg16 features
        self.net = vgg16(pretrained=True, requires_grad=False)
        self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
        self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
        self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
        self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
        self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
        self.load_from_pretrained()
        for param in self.parameters():
            param.requires_grad = False

    def load_from_pretrained(self, name="vgg_lpips"):
        ckpt = get_ckpt_path(name, "taming/modules/autoencoder/lpips")
        self.load_state_dict(
            torch.load(ckpt, map_location=torch.device("cpu")), strict=False
        )
        print("loaded pretrained LPIPS loss from {}".format(ckpt))

    @classmethod
    def from_pretrained(cls, name="vgg_lpips"):
        if name != "vgg_lpips":
            raise NotImplementedError
        model = cls()
        ckpt = get_ckpt_path(name)
        model.load_state_dict(
            torch.load(ckpt, map_location=torch.device("cpu")), strict=False
        )
        return model

    def forward(self, input, target):
        in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target))
        outs0, outs1 = self.net(in0_input), self.net(in1_input)
        feats0, feats1, diffs = {}, {}, {}
        lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
        for kk in range(len(self.chns)):
            feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(
                outs1[kk]
            )
            diffs[kk] = (feats0[kk] - feats1[kk]) ** 2

        res = [
            spatial_average(lins[kk].model(diffs[kk]), keepdim=True)
            for kk in range(len(self.chns))
        ]
        val = res[0]
        for l in range(1, len(self.chns)):
            val += res[l]
        return val


class ScalingLayer(nn.Module):
    def __init__(self):
        super(ScalingLayer, self).__init__()
        self.register_buffer(
            "shift", torch.Tensor([-0.030, -0.088, -0.188])[None, :, None, None]
        )
        self.register_buffer(
            "scale", torch.Tensor([0.458, 0.448, 0.450])[None, :, None, None]
        )

    def forward(self, inp):
        return (inp - self.shift) / self.scale


class NetLinLayer(nn.Module):
    """A single linear layer which does a 1x1 conv"""

    def __init__(self, chn_in, chn_out=1, use_dropout=False):
        super(NetLinLayer, self).__init__()
        layers = (
            [
                nn.Dropout(),
            ]
            if (use_dropout)
            else []
        )
        layers += [
            nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False),
        ]
        self.model = nn.Sequential(*layers)


class vgg16(torch.nn.Module):
    def __init__(self, requires_grad=False, pretrained=True):
        super(vgg16, self).__init__()
        vgg_pretrained_features = models.vgg16(pretrained=pretrained).features
        self.slice1 = torch.nn.Sequential()
        self.slice2 = torch.nn.Sequential()
        self.slice3 = torch.nn.Sequential()
        self.slice4 = torch.nn.Sequential()
        self.slice5 = torch.nn.Sequential()
        self.N_slices = 5
        for x in range(4):
            self.slice1.add_module(str(x), vgg_pretrained_features[x])
        for x in range(4, 9):
            self.slice2.add_module(str(x), vgg_pretrained_features[x])
        for x in range(9, 16):
            self.slice3.add_module(str(x), vgg_pretrained_features[x])
        for x in range(16, 23):
            self.slice4.add_module(str(x), vgg_pretrained_features[x])
        for x in range(23, 30):
            self.slice5.add_module(str(x), vgg_pretrained_features[x])
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        h = self.slice1(X)
        h_relu1_2 = h
        h = self.slice2(h)
        h_relu2_2 = h
        h = self.slice3(h)
        h_relu3_3 = h
        h = self.slice4(h)
        h_relu4_3 = h
        h = self.slice5(h)
        h_relu5_3 = h
        vgg_outputs = namedtuple(
            "VggOutputs", ["relu1_2", "relu2_2", "relu3_3", "relu4_3", "relu5_3"]
        )
        out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
        return out


def normalize_tensor(x, eps=1e-10):
    norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True))
    return x / (norm_factor + eps)


def spatial_average(x, keepdim=True):
    return x.mean([2, 3], keepdim=keepdim)
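
For reference, a minimal usage sketch of the vendored module, not part of the commit itself. The import path mirrors the new file's location; the (N, 3, H, W) layout and the [-1, 1] input range are assumptions based on the ScalingLayer constants, which match the upstream LPIPS release:

# Minimal sketch: per-pair LPIPS distances between two image batches.
import torch
from deepul.hw3_utils.lpips import LPIPS

loss_fn = LPIPS().eval()  # fetches VGG16 weights and the LPIPS checkpoint on first use

# Hypothetical batches standing in for reconstructions and their targets,
# scaled to [-1, 1] as the ScalingLayer expects.
real = torch.rand(8, 3, 32, 32) * 2 - 1
fake = torch.rand(8, 3, 32, 32) * 2 - 1

with torch.no_grad():
    dist = loss_fn(real, fake)  # shape (8, 1, 1, 1): one distance per image pair
print(dist.view(-1))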
