Commit e0cf791 (parent: 286e6cf)

Support for different radii for the deconvolution input tensor
Also append the radius to the checkpoint directory name.

File tree: 4 files changed, +35 -36 lines

main.py (1 addition, 0 deletions)

@@ -13,6 +13,7 @@
 flags.DEFINE_float("learning_rate", 1e-4, "The learning rate of the adam optimizer [1e-4]")
 flags.DEFINE_integer("c_dim", 1, "Dimension of image color [1]")
 flags.DEFINE_integer("scale", 2, "The size of scale factor for preprocessing input image [2]")
+flags.DEFINE_integer("radius", 1, "Max radius of the deconvolution input tensor [1]")
 flags.DEFINE_string("checkpoint_dir", "checkpoint", "Name of checkpoint directory [checkpoint]")
 flags.DEFINE_string("output_dir", "result", "Name of test output directory [result]")
 flags.DEFINE_string("data_dir", "Train", "Name of data directory to train on [FastTrain]")

model.py (8 additions, 7 deletions)

@@ -30,14 +30,16 @@ def __init__(self, sess, config):
     self.is_grayscale = (self.c_dim == 1)
     self.epoch = config.epoch
     self.scale = config.scale
+    self.radius = config.radius
     self.batch_size = config.batch_size
     self.learning_rate = config.learning_rate
     self.threads = config.threads
     self.distort = config.distort
     self.params = config.params

+    self.padding = 4
     # Different image/label sub-sizes for different scaling factors x2, x3, x4
-    scale_factors = [[24, 40], [18, 42], [16, 48]]
+    scale_factors = [[20 + self.padding, 40], [14 + self.padding, 42], [12 + self.padding, 48]]
     self.image_size, self.label_size = scale_factors[self.scale - 2]
     # Testing uses different strides to ensure sub-images line up correctly
     if not self.train:

@@ -48,8 +50,6 @@ def __init__(self, sess, config):
     # Different model layer counts and filter sizes for FSRCNN vs FSRCNN-s (fast), (d, s, m) in paper
     model_params = [[56, 12, 4], [32, 8, 1]]
     self.model_params = model_params[self.fast]
-
-    self.deconv_radius = [3, 5, 7][self.scale - 2]

     self.checkpoint_dir = config.checkpoint_dir
     self.output_dir = config.output_dir
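
Note that the new scale_factors entries reproduce the old sub-image sizes exactly; the change just makes the 4-pixel padding explicit. A quick check (illustrative arithmetic only):

    padding = 4
    old = [[24, 40], [18, 42], [16, 48]]
    new = [[20 + padding, 40], [14 + padding, 42], [12 + padding, 48]]
    assert old == new  # identical image/label sizes for x2, x3, x4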
@@ -169,7 +169,8 @@ def model(self):
     d, s, m = self.model_params

     # Feature Extraction
-    self.weights['w1'] = tf.get_variable('w1', initializer=tf.random_normal([5, 5, 1, d], stddev=0.0378, dtype=tf.float32))
+    size = self.radius * 2 + 1
+    self.weights['w1'] = tf.get_variable('w1', initializer=tf.random_normal([size, size, 1, d], stddev=0.0378, dtype=tf.float32))
     self.biases['b1'] = tf.get_variable('b1', initializer=tf.zeros([d]))
     conv = self.prelu(tf.nn.conv2d(self.images, self.weights['w1'], strides=[1,1,1,1], padding='VALID') + self.biases['b1'], 1)

@@ -196,7 +197,7 @@ def model(self):
     conv = self.prelu(tf.nn.conv2d(conv, expand_weights, strides=[1,1,1,1], padding='SAME') + expand_biases, m + 3)

     # Deconvolution
-    deconv_size = self.deconv_radius * 2 + 1
+    deconv_size = self.radius * self.scale * 2 + 1
     deconv_weights = tf.get_variable('w{}'.format(m + 4), initializer=tf.random_normal([deconv_size, deconv_size, 1, d], stddev=0.0001, dtype=tf.float32))
     deconv_biases = tf.get_variable('b{}'.format(m + 4), initializer=tf.zeros([1]))
     self.weights['w{}'.format(m + 4)], self.biases['b{}'.format(m + 4)] = deconv_weights, deconv_biases
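
Both kernel sizes are now derived from the radius flag: the feature-extraction kernel is (2r + 1) square, and the deconvolution kernel is (2rs + 1) square for scale s. Compared with the previously hard-coded deconvolution radii, this works out as follows (a sketch of the arithmetic, not code from the commit):

    for scale in (2, 3, 4):
        old_deconv = [3, 5, 7][scale - 2] * 2 + 1      # 7, 11, 15
        for radius in (1, 2, 3):
            feat = radius * 2 + 1                      # was fixed at 5
            new_deconv = radius * scale * 2 + 1
            print(scale, radius, feat, new_deconv, old_deconv)

With the default radius of 1 this actually shrinks the kernels: a 3x3 feature-extraction kernel instead of 5x5, and a 5x5 deconvolution kernel instead of 7x7 at x2.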
@@ -220,7 +221,7 @@ def prelu(self, _x, i):
 def save(self, checkpoint_dir, step):
     model_name = "FSRCNN.model"
     d, s, m = self.model_params
-    model_dir = "%s_%s_%s-%s-%s" % ("fsrcnn", self.label_size, d, s, m)
+    model_dir = "%s_%s_%s-%s-%s_%s" % ("fsrcnn", self.label_size, d, s, m, "r"+str(self.radius))
     checkpoint_dir = os.path.join(checkpoint_dir, model_dir)

     if not os.path.exists(checkpoint_dir):

@@ -233,7 +234,7 @@ def save(self, checkpoint_dir, step):
 def load(self, checkpoint_dir):
     print(" [*] Reading checkpoints...")
     d, s, m = self.model_params
-    model_dir = "%s_%s_%s-%s-%s" % ("fsrcnn", self.label_size, d, s, m)
+    model_dir = "%s_%s_%s-%s-%s_%s" % ("fsrcnn", self.label_size, d, s, m, "r"+str(self.radius))
     checkpoint_dir = os.path.join(checkpoint_dir, model_dir)

     ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
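
Appending the radius keeps checkpoints trained with different radii from colliding. With the x2 defaults from above (label_size 40, d=56, s=12, m=4, radius 1), the directory name evaluates to (illustrative):

    model_dir = "%s_%s_%s-%s-%s_%s" % ("fsrcnn", 40, 56, 12, 4, "r" + str(1))
    # -> 'fsrcnn_40_56-12-4_r1'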

sort.py (10 additions, 7 deletions)

@@ -2,26 +2,29 @@

 def main():
     scale = 2
-    radius = [3, 5, 7][scale-2]
+    radius = 2
+    size = radius * scale * 2 + 1
     d = 64 #size of the feature layer
+
     if len(sys.argv) == 2:
         fname=sys.argv[1]
         with open(fname) as f:
             content = f.readlines()
         content = [x.strip() for x in content]

+        x=list(reversed(range(scale)))
+        x=x[-1:]+x[:-1]
         xy = []
-        for i in range(0, scale):
-            for j in range(0, scale):
+        for i in x:
+            for j in x:
                 xy.append([j, i])
-        xy = list(reversed(xy))

         m = []
         for i in range(0, len(xy)):
             xi, yi = xy[i]
-            for x in range(xi, radius*2+1, scale):
-                for y in range(yi, radius*2+1, scale):
-                    m.append(y + x*(radius*2+1))
+            for y in range(yi, size, scale):
+                for x in range(xi, size, scale):
+                    m.append(y + x * size)
         #print(m)
         content = list(reversed(content))
         sort = [content[m[l]].strip(",") for l in range(0, len(m))]
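
sort.py builds an index permutation over the deconvolution kernel, grouping positions by sub-pixel phase (every scale-th row and column). With the script's own values (radius=2, scale=2, so size=9), the new loops cover the whole kernel exactly once; a runnable sketch of what they produce:

    scale, radius = 2, 2
    size = radius * scale * 2 + 1        # 9

    x = list(reversed(range(scale)))     # [1, 0]
    x = x[-1:] + x[:-1]                  # rotated -> [0, 1]
    xy = [[j, i] for i in x for j in x]  # [[0, 0], [1, 0], [0, 1], [1, 1]]

    m = []
    for xi, yi in xy:
        for y in range(yi, size, scale):
            for xx in range(xi, size, scale):
                m.append(y + xx * size)

    assert sorted(m) == list(range(size * size))  # every index of the 9x9 kernel, once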

utils.py (16 additions, 22 deletions)

@@ -120,11 +120,9 @@ def modcrop(image, scale=3):

 def train_input_worker(args):
     image_data, config = args
-    image_size, label_size, stride, scale, distort = config
+    image_size, label_size, stride, scale, padding, distort = config

     single_input_sequence, single_label_sequence = [], []
-    padding = abs(image_size - label_size) // 2 # eg. for 3x: (21 - 11) / 2 = 5
-    label_padding = abs((image_size - 4) - label_size) // 2 # eg. for 3x: (21 - (11 - 4)) / 2 = 7

     input_, label_ = preprocess(image_data, scale, distort=distort)

@@ -133,10 +131,10 @@ def train_input_worker(args):
     else:
         h, w = input_.shape

-    for x in range(0, h - image_size - padding + 1, stride):
-        for y in range(0, w - image_size - padding + 1, stride):
-            sub_input = input_[x + padding : x + padding + image_size, y + padding : y + padding + image_size]
-            x_loc, y_loc = x + label_padding, y + label_padding
+    for x in range(0, h - image_size + 1, stride):
+        for y in range(0, w - image_size + 1, stride):
+            sub_input = input_[x : x + image_size, y : y + image_size]
+            x_loc, y_loc = x + padding, y + padding
             sub_label = label_[x_loc * scale : x_loc * scale + label_size, y_loc * scale : y_loc * scale + label_size]

             sub_input = sub_input.reshape([image_size, image_size, 1])
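
The old code derived two paddings from the image/label size difference; now a single padding (config.padding // 2, i.e. 2 pixels) comes from the model, input sub-images are taken without an offset, and only the label crop is shifted inward. The sizes still line up; a sketch of the arithmetic with the x2 defaults:

    scale = 2
    model_padding = 4                  # self.padding in model.py
    image_size = 20 + model_padding    # 24, from scale_factors
    label_size = 40
    padding = model_padding // 2       # 2, the value utils.py now receives

    # The label crop covers the sub-image interior, upscaled by `scale`.
    assert (image_size - 2 * padding) * scale == label_size

The same loop change is repeated below in train_input_setup and test_input_setup.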
@@ -165,7 +163,7 @@ def thread_train_setup(config):
     pool = Pool(config.threads)

     # Distribute |images_per_thread| images across each worker process
-    config_values = [config.image_size, config.label_size, config.stride, config.scale, config.distort]
+    config_values = [config.image_size, config.label_size, config.stride, config.scale, config.padding // 2, config.distort]
     images_per_thread = len(data) // config.threads
     workers = []
     for thread in range(config.threads):

@@ -202,14 +200,12 @@ def train_input_setup(config):
     Read image files, make their sub-images, and save them as a h5 file format.
     """
     sess = config.sess
-    image_size, label_size, stride, scale = config.image_size, config.label_size, config.stride, config.scale
+    image_size, label_size, stride, scale, padding = config.image_size, config.label_size, config.stride, config.scale, config.padding // 2

     # Load data path
     data = prepare_data(sess, dataset=config.data_dir)

     sub_input_sequence, sub_label_sequence = [], []
-    padding = abs(image_size - label_size) // 2 # eg. for 3x: (21 - 11) / 2 = 5
-    label_padding = abs((image_size - 4) - label_size) // 2 # eg. for 3x: (21 - (11 - 4)) / 2 = 7

     for i in range(len(data)):
         input_, label_ = preprocess(data[i], scale, distort=config.distort)

@@ -219,10 +215,10 @@ def train_input_setup(config):
     else:
         h, w = input_.shape

-    for x in range(0, h - image_size - padding + 1, stride):
-        for y in range(0, w - image_size - padding + 1, stride):
-            sub_input = input_[x + padding : x + padding + image_size, y + padding : y + padding + image_size]
-            x_loc, y_loc = x + label_padding, y + label_padding
+    for x in range(0, h - image_size + 1, stride):
+        for y in range(0, w - image_size + 1, stride):
+            sub_input = input_[x : x + image_size, y : y + image_size]
+            x_loc, y_loc = x + padding, y + padding
             sub_label = label_[x_loc * scale : x_loc * scale + label_size, y_loc * scale : y_loc * scale + label_size]

             sub_input = sub_input.reshape([image_size, image_size, 1])

@@ -242,14 +238,12 @@ def test_input_setup(config):
     Read image files, make their sub-images, and save them as a h5 file format.
     """
     sess = config.sess
-    image_size, label_size, stride, scale = config.image_size, config.label_size, config.stride, config.scale
+    image_size, label_size, stride, scale, padding = config.image_size, config.label_size, config.stride, config.scale, config.padding // 2

     # Load data path
     data = prepare_data(sess, dataset="Test")

     sub_input_sequence, sub_label_sequence = [], []
-    padding = abs(image_size - label_size) // 2 # eg. (21 - 11) / 2 = 5
-    label_padding = abs((image_size - 4) - label_size) // 2 # eg. for 3x: (21 - (11 - 4)) / 2 = 7

     pic_index = 2 # Index of image based on lexicographic order in data folder
     input_, label_ = preprocess(data[pic_index], config.scale)

@@ -260,13 +254,13 @@ def test_input_setup(config):
     h, w = input_.shape

     nx, ny = 0, 0
-    for x in range(0, h - image_size - padding + 1, stride):
+    for x in range(0, h - image_size + 1, stride):
         nx += 1
         ny = 0
-        for y in range(0, w - image_size - padding + 1, stride):
+        for y in range(0, w - image_size + 1, stride):
             ny += 1
-            sub_input = input_[x + padding : x + padding + image_size, y + padding : y + padding + image_size]
-            x_loc, y_loc = x + label_padding, y + label_padding
+            sub_input = input_[x : x + image_size, y : y + image_size]
+            x_loc, y_loc = x + padding, y + padding
             sub_label = label_[x_loc * scale : x_loc * scale + label_size, y_loc * scale : y_loc * scale + label_size]

             sub_input = sub_input.reshape([image_size, image_size, 1])
