Skip to content

Commit 7aeffac

Browse files
committed
add hw3 helpers
1 parent 0cfe35b commit 7aeffac

File tree

3 files changed

+344
-0
lines changed

3 files changed

+344
-0
lines changed

deepul/hw3_helper.py

Lines changed: 232 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,232 @@
1+
import numpy as np
2+
import torch.nn as nn
3+
import torch.utils.data
4+
import torchvision
5+
from torchvision import transforms as transforms
6+
from .utils import *
7+
from .hw3_utils.hw3_models import GoogLeNet
8+
from PIL import Image as PILImage
9+
import scipy.ndimage
10+
import cv2
11+
import deepul.pytorch_util as ptu
12+
13+
import numpy as np
14+
import math
15+
import sys
16+
17+
softmax = None
18+
model = None
19+
device = torch.device("cuda:0")
20+
21+
def plot_gan_training(losses, title, fname):
22+
plt.figure()
23+
n_itr = len(losses)
24+
xs = np.arange(n_itr)
25+
26+
plt.plot(xs, losses, label='loss')
27+
plt.legend()
28+
plt.title(title)
29+
plt.xlabel('Training Iteration')
30+
plt.ylabel('Loss')
31+
savefig(fname)
32+
33+
def q1_gan_plot(data, samples, xs, ys, title, fname):
34+
plt.figure()
35+
plt.hist(samples, bins=50, density=True, alpha=0.7, label='fake')
36+
plt.hist(data, bins=50, density=True, alpha=0.7, label='real')
37+
38+
plt.plot(xs, ys, label='discrim')
39+
plt.legend()
40+
plt.title(title)
41+
savefig(fname)
42+
43+
44+
######################
45+
##### Question 1 #####
46+
######################
47+
48+
def q1_data(n=20000):
49+
assert n % 2 == 0
50+
gaussian1 = np.random.normal(loc=-1.5, scale=0.22, size=(n//2,))
51+
gaussian2 = np.random.normal(loc=0.2, scale=0.6, size=(n//2,))
52+
data = (np.concatenate([gaussian1, gaussian2]) + 1).reshape([-1, 1])
53+
scaled_data = (data - np.min(data)) / (np.max(data) - np.min(data) + 1e-8)
54+
return 2 * scaled_data -1
55+
56+
def visualize_q1_dataset():
57+
data = q1_data()
58+
plt.hist(data, bins=50, alpha=0.7, label='train data')
59+
plt.legend()
60+
plt.show()
61+
62+
63+
def q1_save_results(part, fn):
64+
data = q1_data()
65+
losses, samples1, xs1, ys1, samples_end, xs_end, ys_end = fn(data)
66+
67+
# loss plot
68+
plot_gan_training(losses, 'Q1{} Losses'.format(part), 'results/q1{}_losses.png'.format(part))
69+
70+
# samples
71+
q1_gan_plot(data, samples1, xs1, ys1, 'Q1{} Epoch 1'.format(part), 'results/q1{}_epoch1.png'.format(part))
72+
q1_gan_plot(data, samples_end, xs_end, ys_end, 'Q1{} Final'.format(part), 'results/q1{}_final.png'.format(part))
73+
74+
######################
75+
##### Question 2 #####
76+
######################
77+
78+
def calculate_is(samples):
79+
assert (type(samples[0]) == np.ndarray)
80+
assert (len(samples[0].shape) == 3)
81+
82+
model = GoogLeNet().to(ptu.device)
83+
model.load_state_dict(torch.load("deepul/deepul/hw4_utils/classifier.pt"))
84+
softmax = nn.Sequential(model, nn.Softmax(dim=1))
85+
86+
bs = 100
87+
softmax.eval()
88+
with torch.no_grad():
89+
preds = []
90+
n_batches = int(math.ceil(float(len(samples)) / float(bs)))
91+
for i in range(n_batches):
92+
sys.stdout.write(".")
93+
sys.stdout.flush()
94+
inp = ptu.FloatTensor(samples[(i * bs):min((i + 1) * bs, len(samples))])
95+
pred = ptu.get_numpy(softmax(inp))
96+
preds.append(pred)
97+
preds = np.concatenate(preds, 0)
98+
kl = preds * (np.log(preds) - np.log(np.expand_dims(np.mean(preds, 0), 0)))
99+
kl = np.mean(np.sum(kl, 1))
100+
return np.exp(kl)
101+
102+
def load_q2_data():
103+
train_data = torchvision.datasets.CIFAR10("./data", transform=torchvision.transforms.ToTensor(),
104+
download=True, train=True)
105+
return train_data
106+
107+
def visualize_q2_data():
108+
train_data = load_q2_data()
109+
imgs = train_data.data[:100]
110+
show_samples(imgs, title=f'CIFAR-10 Samples')
111+
112+
def q2_save_results(fn):
113+
train_data = load_q2_data()
114+
train_data = train_data.data.transpose((0, 3, 1, 2)) / 255.0
115+
train_losses, samples = fn(train_data)
116+
117+
print("Inception score:", calculate_is(samples.transpose([0, 3, 1, 2])))
118+
plot_gan_training(train_losses, 'Q2 Losses', 'results/q2_losses.png')
119+
show_samples(samples[:100] * 255.0, fname='results/q2_samples.png', title=f'CIFAR-10 generated samples')
120+
121+
######################
122+
##### Question 3 #####
123+
######################
124+
125+
def load_q3_data():
126+
transform = transforms.Compose([
127+
transforms.ToTensor(),
128+
transforms.Normalize((0.5,), (0.5,))
129+
])
130+
train_data = torchvision.datasets.CIFAR10(root="./data", train=True, download=True, transform=transform).data.transpose((0, 3, 1, 2)) / 255.0
131+
test_data = torchvision.datasets.CIFAR10(root="./data", train=False, download=True, transform=transform).data.transpose((0, 3, 1, 2)) / 255.0
132+
return train_data, test_data
133+
134+
135+
def visualize_q3_data():
136+
train_data, _ = load_q3_data()
137+
imgs = train_data.data[:100]
138+
show_samples(imgs.reshape([100, 28, 28, 1]) * 255.0, title='CIFAR10 samples')
139+
140+
def save_plot(
141+
train_losses: np.ndarray, test_losses: np.ndarray, title: str, fname: str
142+
) -> None:
143+
plt.figure()
144+
if test_losses is None:
145+
plt.plot(train_losses, label="train")
146+
plt.xlabel("Iteration")
147+
else:
148+
n_epochs = len(test_losses) - 1
149+
x_train = np.linspace(0, n_epochs, len(train_losses))
150+
x_test = np.arange(n_epochs + 1)
151+
152+
plt.plot(x_train, train_losses, label="train")
153+
plt.plot(x_test, test_losses, label="test")
154+
plt.xlabel("Epoch")
155+
plt.legend()
156+
plt.title(title)
157+
plt.ylabel("loss")
158+
savefig(fname)
159+
160+
161+
def q3_save_results(fn, part):
162+
train_data, test_data = load_q3_data()
163+
gan_losses, optional_lpips_losses, l2_train_losses, l2_val_losses, recon_show, recon_is = fn(train_data, test_data, test_data[:100])
164+
165+
plot_gan_training(gan_losses, f'Q3{part} Losses', f'results/q3{part}_gan_losses.png')
166+
save_plot(l2_train_losses, l2_val_losses, f'Q3{part} L2 Losses', f'results/q3{part}_l2_losses.png')
167+
if optional_lpips_losses is not None:
168+
save_plot(optional_lpips_losses, None, f'Q3{part} LPIPS Losses', f'results/q3{part}_lpips_losses.png')
169+
show_samples(test_data[:100].transpose(0, 2, 3, 1) * 255.0, nrow=20, fname=f'results/q3{part}_data_samples.png', title=f'Q3{part} CIFAR10 val samples')
170+
show_samples(recon_show * 255.0, nrow=20, fname=f'results/q3{part}_reconstructions.png', title=f'Q3{part} VQGAN reconstructions')
171+
print('inception score:', calculate_is(recon_is.transpose([0, 2, 3, 1])))
172+
print('final_reconstruction_loss:', l2_val_losses[-1])
173+
174+
######################
175+
##### Question 4 #####
176+
######################
177+
178+
def get_colored_mnist(data):
179+
# from https://www.wouterbulten.nl/blog/tech/getting-started-with-gans-2-colorful-mnist/
180+
# Read Lena image
181+
lena = PILImage.open('deepul/deepul/hw4_utils/lena.jpg')
182+
183+
# Resize
184+
batch_resized = np.asarray([scipy.ndimage.zoom(image, (2.3, 2.3, 1), order=1) for image in data])
185+
186+
# Extend to RGB
187+
batch_rgb = np.concatenate([batch_resized, batch_resized, batch_resized], axis=3)
188+
189+
# Make binary
190+
batch_binary = (batch_rgb > 0.5)
191+
192+
batch = np.zeros((data.shape[0], 28, 28, 3))
193+
194+
for i in range(data.shape[0]):
195+
# Take a random crop of the Lena image (background)
196+
x_c = np.random.randint(0, lena.size[0] - 64)
197+
y_c = np.random.randint(0, lena.size[1] - 64)
198+
image = lena.crop((x_c, y_c, x_c + 64, y_c + 64))
199+
image = np.asarray(image) / 255.0
200+
201+
# Invert the colors at the location of the number
202+
image[batch_binary[i]] = 1 - image[batch_binary[i]]
203+
204+
batch[i] = cv2.resize(image, (0, 0), fx=28 / 64, fy=28 / 64, interpolation=cv2.INTER_AREA)
205+
return batch.transpose(0, 3, 1, 2)
206+
207+
def load_q4_data():
208+
train, _ = load_q3_data()
209+
mnist = np.array(train.data.reshape(-1, 28, 28, 1) / 255.0)
210+
colored_mnist = get_colored_mnist(mnist)
211+
return mnist.transpose(0, 3, 1, 2), colored_mnist
212+
213+
def visualize_cyclegan_datasets():
214+
mnist, colored_mnist = load_q4_data()
215+
mnist, colored_mnist = mnist[:100], colored_mnist[:100]
216+
show_samples(mnist.reshape([100, 28, 28, 1]) * 255.0, title=f'MNIST samples')
217+
show_samples(colored_mnist.transpose([0, 2, 3, 1]) * 255.0, title=f'Colored MNIST samples')
218+
219+
def q4_save_results(fn):
220+
mnist, cmnist = load_q4_data()
221+
222+
m1, c1, m2, c2, m3, c3 = fn(mnist, cmnist)
223+
m1, m2, m3 = m1.repeat(3, axis=3), m2.repeat(3, axis=3), m3.repeat(3, axis=3)
224+
mnist_reconstructions = np.concatenate([m1, c1, m2], axis=0)
225+
colored_mnist_reconstructions = np.concatenate([c2, m3, c3], axis=0)
226+
227+
show_samples(mnist_reconstructions * 255.0, nrow=20,
228+
fname='figures/q4_mnist.png',
229+
title=f'Source domain: MNIST')
230+
show_samples(colored_mnist_reconstructions * 255.0, nrow=20,
231+
fname='figures/q4_colored_mnist.png',
232+
title=f'Source domain: Colored MNIST')

deepul/hw3_utils/__init__.py

Whitespace-only changes.

deepul/hw3_utils/hw3_models.py

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
import torch
2+
import torch.nn as nn
3+
4+
5+
class Inception(nn.Module):
6+
def __init__(self, in_planes, kernel_1_x, kernel_3_in, kernel_3_x, kernel_5_in, kernel_5_x, pool_planes):
7+
super(Inception, self).__init__()
8+
# 1x1 conv branch
9+
self.b1 = nn.Sequential(
10+
nn.Conv2d(in_planes, kernel_1_x, kernel_size=1),
11+
nn.BatchNorm2d(kernel_1_x),
12+
nn.ReLU(True),
13+
)
14+
15+
# 1x1 conv -> 3x3 conv branch
16+
self.b2 = nn.Sequential(
17+
nn.Conv2d(in_planes, kernel_3_in, kernel_size=1),
18+
nn.BatchNorm2d(kernel_3_in),
19+
nn.ReLU(True),
20+
nn.Conv2d(kernel_3_in, kernel_3_x, kernel_size=3, padding=1),
21+
nn.BatchNorm2d(kernel_3_x),
22+
nn.ReLU(True),
23+
)
24+
25+
# 1x1 conv -> 5x5 conv branch
26+
self.b3 = nn.Sequential(
27+
nn.Conv2d(in_planes, kernel_5_in, kernel_size=1),
28+
nn.BatchNorm2d(kernel_5_in),
29+
nn.ReLU(True),
30+
nn.Conv2d(kernel_5_in, kernel_5_x, kernel_size=3, padding=1),
31+
nn.BatchNorm2d(kernel_5_x),
32+
nn.ReLU(True),
33+
nn.Conv2d(kernel_5_x, kernel_5_x, kernel_size=3, padding=1),
34+
nn.BatchNorm2d(kernel_5_x),
35+
nn.ReLU(True),
36+
)
37+
38+
# 3x3 pool -> 1x1 conv branch
39+
self.b4 = nn.Sequential(
40+
nn.MaxPool2d(3, stride=1, padding=1),
41+
nn.Conv2d(in_planes, pool_planes, kernel_size=1),
42+
nn.BatchNorm2d(pool_planes),
43+
nn.ReLU(True),
44+
)
45+
46+
def forward(self, x):
47+
y1 = self.b1(x)
48+
y2 = self.b2(x)
49+
y3 = self.b3(x)
50+
y4 = self.b4(x)
51+
return torch.cat([y1,y2,y3,y4], 1)
52+
53+
54+
class GoogLeNet(nn.Module):
55+
def __init__(self):
56+
super(GoogLeNet, self).__init__()
57+
self.pre_layers = nn.Sequential(
58+
nn.Conv2d(3, 192, kernel_size=3, padding=1),
59+
nn.BatchNorm2d(192),
60+
nn.ReLU(True),
61+
)
62+
63+
self.a3 = Inception(192, 64, 96, 128, 16, 32, 32)
64+
self.b3 = Inception(256, 128, 128, 192, 32, 96, 64)
65+
66+
self.max_pool = nn.MaxPool2d(3, stride=2, padding=1)
67+
68+
self.a4 = Inception(480, 192, 96, 208, 16, 48, 64)
69+
self.b4 = Inception(512, 160, 112, 224, 24, 64, 64)
70+
self.c4 = Inception(512, 128, 128, 256, 24, 64, 64)
71+
self.d4 = Inception(512, 112, 144, 288, 32, 64, 64)
72+
self.e4 = Inception(528, 256, 160, 320, 32, 128, 128)
73+
74+
self.a5 = Inception(832, 256, 160, 320, 32, 128, 128)
75+
self.b5 = Inception(832, 384, 192, 384, 48, 128, 128)
76+
77+
self.avgpool = nn.AvgPool2d(8, stride=1)
78+
self.linear = nn.Linear(1024, 10)
79+
80+
def forward(self, x):
81+
x = self.pre_layers(x)
82+
x = self.a3(x)
83+
x = self.b3(x)
84+
x = self.max_pool(x)
85+
x = self.a4(x)
86+
x = self.b4(x)
87+
x = self.c4(x)
88+
x = self.d4(x)
89+
x = self.e4(x)
90+
x = self.max_pool(x)
91+
x = self.a5(x)
92+
x = self.b5(x)
93+
x = self.avgpool(x)
94+
x = x.view(x.size(0), -1)
95+
x = self.linear(x)
96+
return x
97+
98+
def forward_fid(self, x):
99+
x = self.pre_layers(x)
100+
x = self.a3(x)
101+
x = self.b3(x)
102+
x = self.max_pool(x)
103+
x = self.a4(x)
104+
x = self.b4(x)
105+
x = self.c4(x)
106+
x = self.d4(x)
107+
x = self.e4(x)
108+
x = self.max_pool(x)
109+
x = self.a5(x)
110+
x = self.b5(x)
111+
x = self.avgpool(x)
112+
return x

0 commit comments

Comments
 (0)